Skip to content

Commit cee69b4

Browse files
committed
[intel] specialized kernels
1 parent 2b61b09 commit cee69b4

File tree

2 files changed

+128
-28
lines changed

2 files changed

+128
-28
lines changed

third_party/intel/backend/include/sycl_functions.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ inline std::optional<bool> isEnvValueBool(std::string str) {
143143
return std::nullopt;
144144
}
145145

146+
static constexpr int kBlockIOPitchSpecId = 123;
147+
146148
std::tuple<ze_module_handle_t, ze_result_t>
147149
create_module(ze_context_handle_t context, ze_device_handle_t device,
148150
uint8_t *binary_ptr, size_t binary_size, const char *build_flags,
@@ -152,12 +154,25 @@ create_module(ze_context_handle_t context, ze_device_handle_t device,
152154

153155
const ze_module_format_t format =
154156
is_spv ? ZE_MODULE_FORMAT_IL_SPIRV : ZE_MODULE_FORMAT_NATIVE;
157+
158+
uint64_t pitchBytesMode = (useBlockIO ? 64u : 0u); // TODO remove
159+
160+
ze_module_constants_t specConsts{};
161+
uint32_t ids[] = {kBlockIOPitchSpecId};
162+
uint64_t values[] = {pitchBytesMode};
163+
164+
specConsts.numConstants = 1;
165+
specConsts.pConstantIds = ids;
166+
specConsts.pConstantValues = values;
167+
155168
ze_module_desc_t module_description = {};
156169
module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
157170
module_description.format = format;
158171
module_description.inputSize = static_cast<uint32_t>(binary_size);
159172
module_description.pInputModule = binary_ptr;
160173
module_description.pBuildFlags = build_flags;
174+
module_description.pConstants = &specConsts;
175+
161176
ze_module_build_log_handle_t buildlog;
162177
ze_module_handle_t module;
163178
auto error_no =

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 113 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ static int __builtin_ctz(unsigned x) {
3939

4040
namespace {
4141

42+
static constexpr int kBlockIOPitchSpecId = 123;
43+
4244
Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
4345
auto tb = TritonLLVMOpBuilder(loc, rewriter);
4446
if (a && b) {
@@ -338,6 +340,11 @@ struct LoadStoreConversionBase {
338340
triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED");
339341
};
340342

343+
static Value emitGenericLoad(triton::LoadOp op, Value llPtr, Value llMask,
344+
Value llOther, ConversionPatternRewriter &rewriter,
345+
const LLVMTypeConverter *typeConverter,
346+
const LoadStoreConversionBase &base);
347+
341348
struct BlockIOConversionBase : public LoadStoreConversionBase {
342349
explicit BlockIOConversionBase(
343350
const triton::intel::TargetInfo &targetInfo,
@@ -1659,9 +1666,13 @@ struct LoadOpToBlockIOConversion
16591666
std::swap(baseWidth, baseHeight);
16601667
}
16611668
// HW requires the pitch to be at least 64 bytes.
1669+
bool needRuntimePitchCheck = false;
1670+
16621671
if (auto pitchConst = mlir::triton::intel::getFoldedConstantValue(pitch)) {
16631672
if ((*pitchConst * elemSizeInBits / 8) < 64)
16641673
return failure();
1674+
} else {
1675+
needRuntimePitchCheck = true;
16651676
}
16661677

16671678
baseWidth = b.trunc(i32_ty, baseWidth);
@@ -1889,10 +1900,72 @@ struct LoadOpToBlockIOConversion
18891900
}
18901901

18911902
Type llvmResultStructTy = typeConverter->convertType(op.getType());
1892-
Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
1893-
rewriter, llvmResultStructTy);
1894-
rewriter.replaceOp(op, {resultStruct});
18951903

1904+
Value blockIOResult = packLLElements(loc, typeConverter, unpackedLoadedVals,
1905+
rewriter, llvmResultStructTy);
1906+
1907+
Value finalResult;
1908+
if (!needRuntimePitchCheck) {
1909+
finalResult = blockIOResult;
1910+
} else {
1911+
MLIRContext *ctx = rewriter.getContext();
1912+
ModuleOp module = op->getParentOfType<ModuleOp>();
1913+
1914+
auto i32Ty = IntegerType::get(ctx, 32);
1915+
auto fnTy = LLVM::LLVMFunctionType::get(
1916+
i32Ty, ArrayRef<Type>{i32Ty, i32Ty}, /*isVarArg=*/false);
1917+
1918+
LLVM::LLVMFuncOp specFn =
1919+
module.lookupSymbol<LLVM::LLVMFuncOp>("__spirv_SpecConstant");
1920+
if (!specFn) {
1921+
PatternRewriter::InsertionGuard guard(rewriter);
1922+
rewriter.setInsertionPointToStart(module.getBody());
1923+
1924+
ImplicitLocOpBuilder ib(loc, rewriter);
1925+
specFn = LLVM::LLVMFuncOp::create(ib, "__spirv_SpecConstant", fnTy);
1926+
// default linkage is External
1927+
}
1928+
1929+
// Default value (in bytes) if host doesn't specialize this ID.
1930+
// Using 0 means "disable block-IO by default".
1931+
Value specIdVal = LLVM::ConstantOp::create(
1932+
rewriter, loc, i32Ty,
1933+
rewriter.getI32IntegerAttr(kBlockIOPitchSpecId));
1934+
1935+
Value defaultPitchBytes = LLVM::ConstantOp::create(
1936+
rewriter, loc, i32Ty, rewriter.getI32IntegerAttr(0));
1937+
1938+
// llvm.call @__spirv_SpecConstant(i32 specId, i32 default) -> i32
1939+
auto call = LLVM::CallOp::create(
1940+
rewriter, loc, TypeRange{i32Ty}, SymbolRefAttr::get(specFn),
1941+
ValueRange{specIdVal, defaultPitchBytes});
1942+
1943+
Value specPitchBytes = call.getResult();
1944+
1945+
// cond = (specPitchBytes >= 64)
1946+
Value cond = b.icmp_sge(specPitchBytes, b.i32_val(64));
1947+
1948+
// Generic fallback lowering (gather load).
1949+
Value genericResult = emitGenericLoad(op,
1950+
adaptor.getPtr(), // llPtr
1951+
adaptor.getMask(), // llMask
1952+
adaptor.getOther(), // llOther
1953+
rewriter, typeConverter, *this);
1954+
1955+
auto createBlockIOResult = [&]() -> SmallVector<Value, 1> {
1956+
return {blockIOResult};
1957+
};
1958+
1959+
Block &mergeBlock = LLVM::intel::createPredicatedBlock(
1960+
rewriter, loc,
1961+
cond, // true → block-IO
1962+
SmallVector<Value, 1>{genericResult}, // false → generic
1963+
createBlockIOResult);
1964+
1965+
finalResult = mergeBlock.getArgument(0);
1966+
}
1967+
1968+
rewriter.replaceOp(op, finalResult);
18961969
return success();
18971970
}
18981971

@@ -2426,31 +2499,28 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24262499
: ConvertOpToLLVMPattern<triton::LoadOp>(converter, benefit),
24272500
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
24282501

2429-
LogicalResult
2430-
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
2431-
ConversionPatternRewriter &rewriter) const override {
2502+
/// Generic lowering for triton::LoadOp → LLVM struct value.
2503+
static Value emitGenericLoadImpl(triton::LoadOp op, Value llPtr, Value llMask,
2504+
Value llOther,
2505+
ConversionPatternRewriter &rewriter,
2506+
const LLVMTypeConverter *typeConverter,
2507+
const LoadStoreConversionBase &base) {
24322508
Location loc = op->getLoc();
24332509
auto b = TritonLLVMOpBuilder(loc, rewriter);
2434-
auto typeConverter = getTypeConverter();
24352510
MLIRContext *ctx = rewriter.getContext();
24362511

24372512
// original values
24382513
Value ptr = op.getPtr();
24392514
Value mask = op.getMask();
24402515
Value other = op.getOther();
24412516

2442-
// adaptor values
2443-
Value llPtr = adaptor.getPtr();
2444-
Value llMask = adaptor.getMask();
2445-
Value llOther = adaptor.getOther();
2446-
24472517
// Determine the vectorization size
24482518
Type valueElemTy =
24492519
typeConverter->convertType(getElementTypeOrSelf(op.getType()));
24502520
unsigned numElems = getTotalElemsPerThread(op.getType());
2451-
unsigned vec = getVectorSize(ptr);
2521+
unsigned vec = base.getVectorSize(ptr);
24522522
if (llMask)
2453-
vec = std::min<size_t>(vec, getMaskAlignment(mask));
2523+
vec = std::min<std::size_t>(vec, base.getMaskAlignment(mask));
24542524

24552525
SmallVector<Value> ptrElems, maskElems, otherElems;
24562526
bool otherIsSplatConstInt = false;
@@ -2459,9 +2529,10 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24592529
if (isTensorPointerType(ptr.getType())) {
24602530
// fallback to gather load.
24612531
auto tensorType = cast<RankedTensorType>(op.getType());
2462-
std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
2463-
loc, llPtr, tensorType, valueElemTy, rewriter, op.getBoundaryCheck(),
2464-
op.getPadding());
2532+
std::tie(ptrElems, maskElems, otherElems) =
2533+
base.convertBlockPtrToTensorOfPtr(loc, llPtr, tensorType, valueElemTy,
2534+
rewriter, op.getBoundaryCheck(),
2535+
op.getPadding());
24652536
} else {
24662537
// Get the LLVM values for pointers
24672538
ptrElems = unpackLLElements(loc, llPtr, rewriter);
@@ -2503,19 +2574,19 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25032574
if (unsigned canonicalVecStart = getCanonicalIndex(vecStart, regMask);
25042575
vecStart != canonicalVecStart) {
25052576
// For redundant registers, refer back to the canonical load
2506-
for (int iVec = 0; iVec < vec; ++iVec)
2577+
for (int iVec = 0; iVec < static_cast<int>(vec); ++iVec)
25072578
loadedVals.push_back(loadedVals[canonicalVecStart + iVec]);
2508-
25092579
continue;
25102580
}
25112581

25122582
// TODO: optimization when ptr is GEP with constant offset
2513-
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
2583+
const size_t maxWordWidth = std::max<std::size_t>(32, valueElemNBits);
25142584
const size_t totalWidth = valueElemNBits * vec;
25152585
const size_t width = std::min(totalWidth, maxWordWidth);
2516-
const size_t nWords = std::max<size_t>(1, totalWidth / width);
2586+
const size_t nWords = std::max<std::size_t>(1, totalWidth / width);
25172587
const size_t wordNElems = width / valueElemNBits;
25182588
const size_t movWidth = width < 16 ? 16 : width;
2589+
(void)movWidth; // keep variable but silence unused warning
25192590
assert(wordNElems * nWords * numVecs == numElems);
25202591

25212592
Value pred = maskElems.size() ? maskElems[vecStart] : Value{};
@@ -2554,9 +2625,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25542625
retTy, other_, v,
25552626
createIndexAttrConstant(
25562627
rewriter, loc, typeConverter->getIndexType(), ii))
2557-
:
2558-
2559-
v;
2628+
: v;
25602629
}
25612630
}
25622631
assert(other_ && "Expecting a valid value");
@@ -2566,13 +2635,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25662635
auto createLoadWithAttrs = [&]() {
25672636
return SmallVector<Value>{b.load(retTy, addrElem, alignment,
25682637
op.getIsVolatile(),
2569-
getNonTemporalFlag(op))};
2638+
base.getNonTemporalFlag(op))};
25702639
};
25712640

25722641
Value ret;
25732642
if (!pred)
25742643
ret = createLoadWithAttrs()[0];
2575-
else if (canUsePredicatedInstructions(op))
2644+
else if (base.canUsePredicatedInstructions(op))
25762645
ret = TritonGEN::PredicatedLoadOp::create(
25772646
rewriter, loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
25782647
else {
@@ -2604,13 +2673,29 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
26042673
} // end vec
26052674

26062675
Type llvmResultStructTy = typeConverter->convertType(op.getType());
2607-
Value resultStruct = packLLElements(loc, typeConverter, loadedVals,
2608-
rewriter, llvmResultStructTy);
2676+
return packLLElements(loc, typeConverter, loadedVals, rewriter,
2677+
llvmResultStructTy);
2678+
}
2679+
2680+
LogicalResult
2681+
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
2682+
ConversionPatternRewriter &rewriter) const override {
2683+
Value resultStruct = emitGenericLoadImpl(
2684+
op, adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(), rewriter,
2685+
getTypeConverter(), *this);
26092686
rewriter.replaceOp(op, {resultStruct});
26102687
return success();
26112688
}
26122689
};
26132690

2691+
static Value emitGenericLoad(triton::LoadOp op, Value llPtr, Value llMask,
2692+
Value llOther, ConversionPatternRewriter &rewriter,
2693+
const LLVMTypeConverter *typeConverter,
2694+
const LoadStoreConversionBase &base) {
2695+
return LoadOpConversion::emitGenericLoadImpl(op, llPtr, llMask, llOther,
2696+
rewriter, typeConverter, base);
2697+
}
2698+
26142699
struct StoreOpToBlockIOConversion
26152700
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
26162701
public BlockIOConversionBase {

0 commit comments

Comments
 (0)