diff --git a/third_party/intel/backend/include/sycl_functions.h b/third_party/intel/backend/include/sycl_functions.h
index 1c9c071957..c06af5acdd 100644
--- a/third_party/intel/backend/include/sycl_functions.h
+++ b/third_party/intel/backend/include/sycl_functions.h
@@ -143,6 +143,8 @@ inline std::optional<bool> isEnvValueBool(std::string str) {
   return std::nullopt;
 }
 
+static constexpr uint32_t kBlockIOPitchSpecId = 123;
+
 std::tuple<ze_module_handle_t, ze_result_t>
 create_module(ze_context_handle_t context, ze_device_handle_t device,
               uint8_t *binary_ptr, size_t binary_size, const char *build_flags,
@@ -152,12 +154,25 @@ create_module(ze_context_handle_t context, ze_device_handle_t device,
   const ze_module_format_t format =
       is_spv ? ZE_MODULE_FORMAT_IL_SPIRV : ZE_MODULE_FORMAT_NATIVE;
+
+  // Must be 32 bits wide to match the i32 spec constant read in the kernel.
+  uint32_t pitchBytesMode = 0u; // TODO: test-only placeholder; remove.
+
+  ze_module_constants_t specConsts{};
+  uint32_t ids[] = {kBlockIOPitchSpecId};
+  const void *values[] = {&pitchBytesMode};
+
+  specConsts.numConstants = 1;
+  specConsts.pConstantIds = ids;
+  specConsts.pConstantValues = values;
+
   ze_module_desc_t module_description = {};
   module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
   module_description.format = format;
   module_description.inputSize = static_cast<uint32_t>(binary_size);
   module_description.pInputModule = binary_ptr;
   module_description.pBuildFlags = build_flags;
+  module_description.pConstants = &specConsts;
+
   ze_module_build_log_handle_t buildlog;
   ze_module_handle_t module;
   auto error_no =
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 074ce14ca8..1d851a3071 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -39,6 +39,8 @@ static int __builtin_ctz(unsigned x) {
 
 namespace {
 
+// Must stay in sync with kBlockIOPitchSpecId in sycl_functions.h.
+static constexpr int kBlockIOPitchSpecId = 123;
+
 Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
   auto tb = TritonLLVMOpBuilder(loc, rewriter);
   if (a && b) {
@@ -338,6 +340,11 @@ struct LoadStoreConversionBase {
       triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED");
 };
 
+static Value emitGenericLoad(triton::LoadOp op, Value llPtr, Value llMask,
+                             Value llOther, ConversionPatternRewriter &rewriter,
+                             const LLVMTypeConverter *typeConverter,
+                             const LoadStoreConversionBase &base);
+
 struct BlockIOConversionBase : public LoadStoreConversionBase {
   explicit BlockIOConversionBase(
      const triton::intel::TargetInfo &targetInfo,
@@ -1659,9 +1666,13 @@ struct LoadOpToBlockIOConversion
       std::swap(baseWidth, baseHeight);
     }
     // HW requires the pitch to be at least 64 bytes.
+    bool needRuntimePitchCheck = false;
+
     if (auto pitchConst = mlir::triton::intel::getFoldedConstantValue(pitch)) {
       if ((*pitchConst * elemSizeInBits / 8) < 64)
         return failure();
+    } else {
+      // The pitch is unknown at compile time: gate the block-IO path on a
+      // runtime specialization constant instead (see below).
+      needRuntimePitchCheck = true;
     }
 
     baseWidth = b.trunc(i32_ty, baseWidth);
@@ -1695,204 +1706,213 @@ struct LoadOpToBlockIOConversion
                    << " bits)\n";
     });
 
-    ValueTable loadVals;
-    for (int outer = 0; outer < numRepOuter; ++outer) {
-      for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
-        for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
-          LLVM_DEBUG({
-            llvm::dbgs() << "outer, rep, k: " << outer << ", " << rep << ", "
-                         << k << "\n";
-          });
+    using ValueTable = std::map<std::pair<int, int>, Value>;
+
+    // The block-IO lowering is wrapped in a lambda so that it can be emitted
+    // either unconditionally or under the spec-constant guard below.
+    auto buildBlockIOResult = [&]() -> Value {
+      ValueTable loadVals;
+
+      for (int outer = 0; outer < numRepOuter; ++outer) {
+        for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
+          for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) {
+
+            const int loadIdx =
+                (outer * numLoadPerOutRepCluster *
+                 (numRepInner / numOperandsInnerDimPerLoad)) +
+                rep * (numRepInner / numOperandsInnerDimPerLoad) +
+                k / numOperandsInnerDimPerLoad;
+
+            const auto offset = tileLayout.apply(
+                {{kOffset, 0}, {kIteration, 0}, {kLoad, loadIdx}});
+            assert(offset.size() == 2);
+
+            const auto layoutOffsetX = offset[dimInner].second;
+            const auto layoutOffsetY = offset[dimOuter].second;
+
+            Value offsetX, offsetY;
+            switch (opIdx) {
+            case DpasEncodingAttr::OpIdx::OperandA: {
+              offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
+                              b.i32_val(layoutOffsetY));
+              offsetX = b.i32_val(layoutOffsetX);
+            } break;
+            case DpasEncodingAttr::OpIdx::OperandB: {
+              offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
+                              b.i32_val(layoutOffsetX));
+              offsetY = b.i32_val(layoutOffsetY);
+            } break;
+            case DpasEncodingAttr::OpIdx::OperandC:
+              llvm_unreachable("unexpected OpIdx::OperandC");
+            }
 
-          const int loadIdx = (outer * numLoadPerOutRepCluster *
-                               (numRepInner / numOperandsInnerDimPerLoad)) +
-                              rep * (numRepInner / numOperandsInnerDimPerLoad) +
-                              k / numOperandsInnerDimPerLoad;
-          LLVM_DEBUG(llvm::dbgs() << "loadIdx: " << loadIdx << "\n");
-
-          const auto offset = tileLayout.apply(
-              {{kOffset, 0}, {kIteration, 0}, {kLoad, loadIdx}});
-          assert(offset.size() == 2);
-
-          const auto layoutOffsetX = offset[dimInner].second;
-          const auto layoutOffsetY = offset[dimOuter].second;
-          LLVM_DEBUG({
-            llvm::dbgs() << "x offset ll: " << layoutOffsetX << "\n";
-            llvm::dbgs() << "y offset ll: " << layoutOffsetY << "\n";
-          });
+            offsetX = b.add(offsetX, offsetBaseX);
+            offsetY = b.add(offsetY, offsetBaseY);
 
-          Value offsetX, offsetY;
-          switch (opIdx) {
-          case DpasEncodingAttr::OpIdx::OperandA: {
-            LLVM_DEBUG({
-              llvm::dbgs() << "x offset: " << k * repKStride << "\n";
-              llvm::dbgs() << "y offset: "
-                           << outer * repOuterStride + rep * repStride << "\n";
-            });
-            offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
-                            b.i32_val(layoutOffsetY));
-            offsetX = b.i32_val(layoutOffsetX);
-          } break;
-          case DpasEncodingAttr::OpIdx::OperandB: {
-            LLVM_DEBUG({
-              llvm::dbgs() << "x offset: "
-                           << outer * repOuterStride + rep * repStride << "\n";
-              llvm::dbgs() << "y offset: " << k * repKStride << "\n";
-            });
-            offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
-                            b.i32_val(layoutOffsetX));
-            offsetY = b.i32_val(layoutOffsetY);
-          } break;
-          case DpasEncodingAttr::OpIdx::OperandC: {
-            llvm_unreachable("unexpected OpIdx::OperandC");
-          } break;
-          }
+            // Column-major memory: swap X and Y because the HW block-IO only
+            // supports a row-major layout.
+            if (!memoryRowMajor)
+              std::swap(offsetX, offsetY);
 
-          offsetX = b.add(offsetX, offsetBaseX);
-          offsetY = b.add(offsetY, offsetBaseY);
+            // Adjust the block-IO parameters to the HW limits on transposed
+            // loads.
+            if (isTransposeRequired)
+              offsetX = b.udiv(offsetX, b.i32_val(32 / originalElemBits));
 
-          if (!memoryRowMajor) {
-            // Column major memory. We need to swap the X and Y because HW only
-            // support row major memory layout.
-            std::swap(offsetX, offsetY);
-          }
-
-          if (isTransposeRequired) {
-            // adjust the block io parameter to align HW's limitations on
-            // transposing load.
-            offsetX = b.udiv(offsetX, b.i32_val(32 / originalElemBits));
-          }
+            auto load2dOp = TritonGEN::Matrix2DBlockLoadOp::create(
+                rewriter, loc, load2DGenXType,
+                /*ptr*/ base,
+                /*base_width*/ b.mul(baseWidth, elemSizeInBytes),
+                /*base_height*/ baseHeight,
+                /*base_pitch*/ b.mul(pitch, elemSizeInBytes),
+                /*x*/ offsetX,
+                /*y*/ offsetY,
+                /*elem_size_in_bits*/ elemSizeInBits,
+                /*tile_width*/ tileWidth,
+                /*tile_height*/ tileHeight,
+                /*v_blocks*/ vBlocks,
+                /*transpose*/ isTransposeRequired,
+                /*vnni_transform*/
+                (usePackedType && !isOperandA && !isTransposeRequired &&
+                 originalElemBits != 32));
+
+            if (failed(load2dOp.verify())) {
+              rewriter.eraseOp(load2dOp);
+              // Unlike the old inline path, the lambda cannot return
+              // failure() to the caller, so abort instead.
+              llvm::report_fatal_error(
+                  "Matrix2DBlockLoad verification failed in blockIO path");
+            }
 
-          auto load2dOp = TritonGEN::Matrix2DBlockLoadOp::create(
-              rewriter, loc, load2DGenXType,
-              /*ptr*/ base,
-              /*base_width*/ b.mul(baseWidth, elemSizeInBytes),
-              /*base_height*/ baseHeight,
-              /*base_pitch*/ b.mul(pitch, elemSizeInBytes),
-              /*x*/ offsetX,
-              /*y*/ offsetY,
-              /*elem_size_in_bits*/ elemSizeInBits,
-              /*tile_width*/ tileWidth,
-              /*tile_height*/ tileHeight,
-              /*v_blocks*/ vBlocks,
-              /*transpose*/ isTransposeRequired,
-              /*vnni_transform*/
-              (usePackedType && !isOperandA && !isTransposeRequired &&
-               originalElemBits != 32));
-          if (failed(load2dOp.verify())) {
-            // delete the op so that the verifier will not abort the pass
-            // pipeline later, as we can fail this path and try a different
-            // approach.
-            rewriter.eraseOp(load2dOp);
-            return failure();
-          }
-          LLVM_DEBUG(llvm::dbgs() << "Generated load op: " << load2dOp << "\n");
-
-          unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
-                                      ? numOperandsOuterDimPerLoad
-                                      : numOperandsInnerDimPerLoad;
-          unsigned packedColNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
-                                      ? numOperandsInnerDimPerLoad
-                                      : numOperandsOuterDimPerLoad;
-
-          // Decompose the return value to multiple operands.
-          unsigned packedColNumPerVBlock = packedColNum / vBlocks;
-          for (int vblk = 0; vblk < vBlocks; ++vblk)
-            for (int row = 0; row < packedRowNum; ++row)
-              for (int col = 0; col < packedColNumPerVBlock; ++col) {
-
-                unsigned operandStartOffset = (vblk * packedRowNum + row) *
-                                              packedColNumPerVBlock *
-                                              packedElemsPerLanePerDPASInst;
-
-                SmallVector<int32_t> indices(packedElemsPerLanePerDPASInst);
-                for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst;
-                     ++elemIdx) {
-                  indices[elemIdx] = operandStartOffset +
-                                     elemIdx * packedColNumPerVBlock + col;
-                  LLVM_DEBUG({
-                    llvm::dbgs() << "indices[" << elemIdx << "]" << " = "
-                                 << indices[elemIdx] << "\n";
-                  });
-                }
-                DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices);
-                Value loadVal = LLVM::ShuffleVectorOp::create(
-                    rewriter, loc, packedDPASOperandType, load2dOp, load2dOp,
-                    attr);
-
-                // Save the decomposed vals to the map;
-                switch (opIdx) {
-                case DpasEncodingAttr::OpIdx::OperandA: {
-                  LLVM_DEBUG({
-                    llvm::dbgs() << "load vals index: "
-                                 << std::to_string(outer * packedRowNum *
-                                                       numLoadPerOutRepCluster +
-                                                   rep * packedRowNum + row)
-                                 << ", "
-                                 << std::to_string(
-                                        k + vblk * packedColNumPerVBlock + col)
-                                 << "\n";
-                  });
-                  loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
-                                rep * packedRowNum + row,
-                            k + vblk * packedColNumPerVBlock + col}] =
-                      b.bitcast(loadVal, unpackedDPASOperandType);
-                } break;
-                case DpasEncodingAttr::OpIdx::OperandB: {
-                  LLVM_DEBUG({
-                    llvm::dbgs()
-                        << "load vals index: "
-                        << std::to_string(outer * packedColNum *
-                                              numLoadPerOutRepCluster +
-                                          rep * packedColNum +
-                                          vblk * packedColNumPerVBlock + col)
-                        << ", " << std::to_string(k + row) << "\n";
-                  });
-                  loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
-                                rep * packedColNum +
-                                vblk * packedColNumPerVBlock + col,
-                            k + row}] =
-                      b.bitcast(loadVal, unpackedDPASOperandType);
-                } break;
-                case DpasEncodingAttr::OpIdx::OperandC: {
-                  llvm_unreachable("unexpected OpIdx::OperandC");
-                } break;
-                }
-              }
-        }
-      }
-    }
+            unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
+                                        ? numOperandsOuterDimPerLoad
+                                        : numOperandsInnerDimPerLoad;
+            unsigned packedColNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
+                                        ? numOperandsInnerDimPerLoad
+                                        : numOperandsOuterDimPerLoad;
+
+            // Decompose the return value to multiple operands.
+            unsigned packedColNumPerVBlock = packedColNum / vBlocks;
+            for (int vblk = 0; vblk < vBlocks; ++vblk)
+              for (int row = 0; row < packedRowNum; ++row)
+                for (int col = 0; col < packedColNumPerVBlock; ++col) {
+                  unsigned operandStartOffset = (vblk * packedRowNum + row) *
+                                                packedColNumPerVBlock *
+                                                packedElemsPerLanePerDPASInst;
+
+                  SmallVector<int32_t> indices(packedElemsPerLanePerDPASInst);
+                  for (int elemIdx = 0;
+                       elemIdx < packedElemsPerLanePerDPASInst; ++elemIdx)
+                    indices[elemIdx] = operandStartOffset +
+                                       elemIdx * packedColNumPerVBlock + col;
+
+                  DenseI32ArrayAttr attr =
+                      rewriter.getDenseI32ArrayAttr(indices);
+                  Value loadVal = LLVM::ShuffleVectorOp::create(
+                      rewriter, loc, packedDPASOperandType, load2dOp, load2dOp,
+                      attr);
+
+                  // Save the decomposed values into the map.
+                  switch (opIdx) {
+                  case DpasEncodingAttr::OpIdx::OperandA:
+                    loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
+                                  rep * packedRowNum + row,
+                              k + vblk * packedColNumPerVBlock + col}] =
+                        b.bitcast(loadVal, unpackedDPASOperandType);
+                    break;
+                  case DpasEncodingAttr::OpIdx::OperandB:
+                    loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
+                                  rep * packedColNum +
+                                  vblk * packedColNumPerVBlock + col,
+                              k + row}] =
+                        b.bitcast(loadVal, unpackedDPASOperandType);
+                    break;
+                  case DpasEncodingAttr::OpIdx::OperandC:
+                    llvm_unreachable("unexpected OpIdx::OperandC");
+                  }
+                }
+          }
+        }
+      }
 
-    // Extract the value returned by the load ops. And put the values in the
-    // expected order for the layout.
-    SmallVector<Value> unpackedLoadedVals;
-    for (int outer = 0; outer < numRepOuter; ++outer) {
-      for (int k = 0; k < numRepInner; ++k) {
-        for (int rep = 0; rep < repCluster[unsigned(opIdx)]; ++rep) {
-          if (loadVals.find({outer * repCluster[unsigned(opIdx)] + rep, k}) ==
-              loadVals.end()) {
-            // generate a nice error message before the throw below aborts our
-            // pipeline
-            llvm::errs() << "Failed to find key at "
-                         << outer * repCluster[unsigned(opIdx)] + rep << ", "
-                         << k << "\n";
-          }
-          Value loadVal =
-              loadVals.at({outer * repCluster[unsigned(opIdx)] + rep, k});
-          VectorType loadTy = cast<VectorType>(loadVal.getType());
-          for (int i = 0; i < loadTy.getNumElements(); ++i) {
-            auto val = b.extract_element(loadVal, b.i32_val(i));
-            unpackedLoadedVals.push_back(val);
-          }
-        }
-      }
-    }
+      // Extract the values returned by the load ops and put them in the
+      // order expected by the layout.
+      SmallVector<Value> unpackedLoadedVals;
+      for (int outer = 0; outer < numRepOuter; ++outer) {
+        for (int k = 0; k < numRepInner; ++k) {
+          for (int rep = 0; rep < repCluster[unsigned(opIdx)]; ++rep) {
+            auto it =
+                loadVals.find({outer * repCluster[unsigned(opIdx)] + rep, k});
+            if (it == loadVals.end()) {
+              // Emit a useful error message before aborting the pipeline.
+              llvm::errs() << "Failed to find key at "
+                           << outer * repCluster[unsigned(opIdx)] + rep << ", "
+                           << k << "\n";
+              llvm::report_fatal_error("loadVals lookup failed");
+            }
+            Value loadVal = it->second;
+            VectorType loadTy = cast<VectorType>(loadVal.getType());
+            for (int i = 0; i < loadTy.getNumElements(); ++i) {
+              auto val = b.extract_element(loadVal, b.i32_val(i));
+              unpackedLoadedVals.push_back(val);
+            }
+          }
+        }
+      }
 
-    Type llvmResultStructTy = typeConverter->convertType(op.getType());
-    Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
-                                        rewriter, llvmResultStructTy);
-    rewriter.replaceOp(op, {resultStruct});
+      Type llvmResultStructTy = typeConverter->convertType(op.getType());
+      return packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter,
+                            llvmResultStructTy);
+    };
+
+    Value finalResult;
+    if (!needRuntimePitchCheck) {
+      // Pitch already validated at compile time; no spec-constant gating
+      // needed, always use block-IO.
+      finalResult = buildBlockIOResult();
+    } else {
+      // The pitch is only known at run time: read it from a SPIR-V
+      // specialization constant and select between the block-IO and the
+      // generic lowering.
+      MLIRContext *ctx = rewriter.getContext();
+      ModuleOp module = op->getParentOfType<ModuleOp>();
+
+      auto i32Ty = IntegerType::get(ctx, 32);
+      auto fnTy = LLVM::LLVMFunctionType::get(
+          i32Ty, ArrayRef<Type>{i32Ty, i32Ty}, /*isVarArg=*/false);
+
+      LLVM::LLVMFuncOp specFn =
+          module.lookupSymbol<LLVM::LLVMFuncOp>("__spirv_SpecConstant");
+      if (!specFn) {
+        PatternRewriter::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(module.getBody());
+        ImplicitLocOpBuilder ib(loc, rewriter);
+        specFn = LLVM::LLVMFuncOp::create(ib, "__spirv_SpecConstant", fnTy);
+      }
+
+      Value specIdVal = LLVM::ConstantOp::create(
+          rewriter, loc, i32Ty,
+          rewriter.getI32IntegerAttr(kBlockIOPitchSpecId));
+
+      // Default value used when the module is built without specialization.
+      Value defaultPitchBytes = LLVM::ConstantOp::create(
+          rewriter, loc, i32Ty, rewriter.getI32IntegerAttr(0));
+
+      auto call = LLVM::CallOp::create(
+          rewriter, loc, TypeRange{i32Ty}, SymbolRefAttr::get(specFn),
+          ValueRange{specIdVal, defaultPitchBytes});
+
+      Value specPitchBytes = call.getResult();
+
+      // Mirrors the compile-time check above: HW requires the pitch to be
+      // at least 64 bytes.
+      Value cond = b.icmp_sge(specPitchBytes, b.i32_val(64));
+
+      // Emit the generic (gather) lowering unconditionally; the predicated
+      // block below selects between it and the block-IO result.
+      Value genericResult = emitGenericLoad(op,
+                                            adaptor.getPtr(),   // llPtr
+                                            adaptor.getMask(),  // llMask
+                                            adaptor.getOther(), // llOther
+                                            rewriter, typeConverter, *this);
+
+      auto createBlockIOResult = [&]() -> SmallVector<Value> {
+        Value blockIOResult = buildBlockIOResult();
+        return {blockIOResult};
+      };
+
+      Block &mergeBlock = LLVM::intel::createPredicatedBlock(
+          rewriter, loc,
+          cond,                              // then: block-IO path
+          SmallVector<Value>{genericResult}, // else: generic path
+          createBlockIOResult);
+
+      finalResult = mergeBlock.getArgument(0);
+    }
+
+    rewriter.replaceOp(op, finalResult);
     return success();
   }
 
@@ -2426,12 +2446,14 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
       : ConvertOpToLLVMPattern(converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
 
-  LogicalResult
-  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  /// Generic lowering for triton::LoadOp → LLVM struct value.
+  static Value emitGenericLoadImpl(triton::LoadOp op, Value llPtr,
+                                   Value llMask, Value llOther,
+                                   ConversionPatternRewriter &rewriter,
+                                   const LLVMTypeConverter *typeConverter,
+                                   const LoadStoreConversionBase &base) {
     Location loc = op->getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    auto typeConverter = getTypeConverter();
     MLIRContext *ctx = rewriter.getContext();
 
     // original values
@@ -2439,18 +2461,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     Value ptr = op.getPtr();
     Value mask = op.getMask();
     Value other = op.getOther();
 
-    // adaptor values
-    Value llPtr = adaptor.getPtr();
-    Value llMask = adaptor.getMask();
-    Value llOther = adaptor.getOther();
-
     // Determine the vectorization size
     Type valueElemTy =
         typeConverter->convertType(getElementTypeOrSelf(op.getType()));
     unsigned numElems = getTotalElemsPerThread(op.getType());
-    unsigned vec = getVectorSize(ptr);
+    unsigned vec = base.getVectorSize(ptr);
     if (llMask)
-      vec = std::min(vec, getMaskAlignment(mask));
+      vec = std::min(vec, base.getMaskAlignment(mask));
 
     SmallVector<Value> ptrElems, maskElems, otherElems;
     bool otherIsSplatConstInt = false;
@@ -2459,9 +2476,10 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     if (isTensorPointerType(ptr.getType())) {
       // fallback to gather load.
       auto tensorType = cast<RankedTensorType>(op.getType());
-      std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
-          loc, llPtr, tensorType, valueElemTy, rewriter, op.getBoundaryCheck(),
-          op.getPadding());
+      std::tie(ptrElems, maskElems, otherElems) =
+          base.convertBlockPtrToTensorOfPtr(loc, llPtr, tensorType,
+                                            valueElemTy, rewriter,
+                                            op.getBoundaryCheck(),
+                                            op.getPadding());
     } else {
       // Get the LLVM values for pointers
       ptrElems = unpackLLElements(loc, llPtr, rewriter);
@@ -2503,19 +2521,19 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
       if (unsigned canonicalVecStart = getCanonicalIndex(vecStart, regMask);
           vecStart != canonicalVecStart) {
         // For redundant registers, refer back to the canonical load.
-        for (int iVec = 0; iVec < vec; ++iVec)
+        for (int iVec = 0; iVec < static_cast<int>(vec); ++iVec)
           loadedVals.push_back(loadedVals[canonicalVecStart + iVec]);
         continue;
       }
 
       // TODO: optimization when ptr is GEP with constant offset
-      const size_t maxWordWidth = std::max(32, valueElemNBits);
+      const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
       const size_t totalWidth = valueElemNBits * vec;
       const size_t width = std::min(totalWidth, maxWordWidth);
-      const size_t nWords = std::max(1, totalWidth / width);
+      const size_t nWords = std::max<size_t>(1, totalWidth / width);
       const size_t wordNElems = width / valueElemNBits;
       const size_t movWidth = width < 16 ? 16 : width;
+      (void)movWidth; // currently unused on this path; silence the warning
       assert(wordNElems * nWords * numVecs == numElems);
 
       Value pred = maskElems.size() ? maskElems[vecStart] : Value{};
@@ -2554,9 +2572,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
                       retTy, other_, v,
                       createIndexAttrConstant(
                           rewriter, loc, typeConverter->getIndexType(), ii))
-                :
-
-                v;
+                : v;
         }
       }
       assert(other_ && "Expecting a valid value");
@@ -2566,13 +2582,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
       auto createLoadWithAttrs = [&]() {
         return SmallVector<Value>{b.load(retTy, addrElem, alignment,
                                          op.getIsVolatile(),
-                                         getNonTemporalFlag(op))};
+                                         base.getNonTemporalFlag(op))};
       };
 
       Value ret;
       if (!pred)
         ret = createLoadWithAttrs()[0];
-      else if (canUsePredicatedInstructions(op))
+      else if (base.canUsePredicatedInstructions(op))
        ret = TritonGEN::PredicatedLoadOp::create(
            rewriter, loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
       else {
@@ -2604,13 +2620,29 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     } // end vec
 
     Type llvmResultStructTy = typeConverter->convertType(op.getType());
-    Value resultStruct = packLLElements(loc, typeConverter, loadedVals,
-                                        rewriter, llvmResultStructTy);
+    return packLLElements(loc, typeConverter, loadedVals, rewriter,
+                          llvmResultStructTy);
+  }
+
+  LogicalResult
+  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value resultStruct = emitGenericLoadImpl(
+        op, adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(), rewriter,
+        getTypeConverter(), *this);
     rewriter.replaceOp(op, {resultStruct});
     return success();
   }
 };
 
+static Value emitGenericLoad(triton::LoadOp op, Value llPtr, Value llMask,
+                             Value llOther, ConversionPatternRewriter &rewriter,
+                             const LLVMTypeConverter *typeConverter,
+                             const LoadStoreConversionBase &base) {
+  return LoadOpConversion::emitGenericLoadImpl(op, llPtr, llMask, llOther,
+                                               rewriter, typeConverter, base);
+}
+
 struct StoreOpToBlockIOConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
       public BlockIOConversionBase {
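---

Note for reviewers: the two halves of this patch communicate through SPIR-V specialization constant 123 (kBlockIOPitchSpecId). The sketch below is a minimal, self-contained illustration of the host-side half of that contract; it is not part of the patch. It assumes a SPIR-V binary whose module declares a 32-bit OpSpecConstant decorated with SpecId 123, and the names create_module_with_pitch and pitchBytes are illustrative only.

  #include <cstdint>
  #include <level_zero/ze_api.h>

  ze_result_t create_module_with_pitch(ze_context_handle_t ctx,
                                       ze_device_handle_t dev,
                                       const uint8_t *spirv, size_t spirvSize,
                                       uint32_t pitchBytes,
                                       ze_module_handle_t *outModule) {
    // Pin spec constant 123 to the pitch (in bytes) before the JIT runs. The
    // value buffer must match the width of the OpSpecConstant (32 bits here).
    constexpr uint32_t kSpecId = 123; // must equal kBlockIOPitchSpecId
    const uint32_t ids[] = {kSpecId};
    const void *values[] = {&pitchBytes};

    ze_module_constants_t consts{};
    consts.numConstants = 1;
    consts.pConstantIds = ids;
    consts.pConstantValues = values;

    ze_module_desc_t desc{};
    desc.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
    desc.format = ZE_MODULE_FORMAT_IL_SPIRV;
    desc.inputSize = spirvSize;
    desc.pInputModule = spirv;
    desc.pConstants = &consts; // nullptr would keep the SPIR-V defaults
    return zeModuleCreate(ctx, dev, &desc, outModule, /*phBuildLog=*/nullptr);
  }

On the device side, the generated __spirv_SpecConstant(123, 0) call evaluates to the host-supplied value once the module is specialized. With the test-only default of 0 installed by create_module above, the cond = (specPitchBytes >= 64) comparison is false and every gated load takes the generic gather path.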