Commit fc9e01b

[Intel] Cleanup loadSharedToDistributed and storeDistributedToShared (#2985)
Their implementations are identical to their upstream versions. This PR also syncs `MemoryOpToLLVM.cpp` from upstream.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 3b860a6 commit fc9e01b
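
The cleanup follows a simple pattern: the Intel backend carried byte-for-byte copies of `loadSharedToDistributed` and `storeDistributedToShared` that already exist upstream, so the copies are deleted and the call sites drop the `mlir::triton::intel::` qualification and resolve to the shared upstream helpers instead. Below is a minimal, self-contained C++ sketch of that idea only; the names `upstream::storeHelper` and `intel::storeHelper` are illustrative stand-ins, not real Triton APIs.

#include <iostream>
#include <vector>

namespace upstream {
// Shared helper; stands in for the upstream storeDistributedToShared.
inline int storeHelper(const std::vector<int> &vals) {
  int emitted = 0;
  for (int v : vals) {
    (void)v;
    ++emitted; // pretend we emit one store per element
  }
  return emitted;
}
} // namespace upstream

namespace intel {
// Before the cleanup an identical copy of storeHelper lived here; afterwards
// the backend simply reuses the upstream helper.
using upstream::storeHelper;
} // namespace intel

int main() {
  std::vector<int> vals{1, 2, 3};
  // Call sites keep working, but there is only one implementation to maintain.
  std::cout << intel::storeHelper(vals) << " stores emitted\n"; // prints 3
  return 0;
}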

2 files changed: +34, -62 lines changed


third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 34 additions & 10 deletions
@@ -33,14 +33,39 @@ void lowerDistributedToShared(
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
   auto elemTy = typeConverter->convertType(srcTy.getElementType());
 
-  auto smemBase = smemObj.getBase();
-  auto dstStrides = smemObj.getStrides();
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
-  mlir::triton::intel::storeDistributedToShared(dstTy, srcTy, elemTy, inVals,
-                                                smemObj, loc, rewriter,
-                                                targetInfo, llvmOpCount);
+  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
+                           targetInfo, llvmOpCount);
 }
 
+struct GlobalScratchAllocOpConversion
+    : public ConvertOpToLLVMPattern<triton::gpu::GlobalScratchAllocOp> {
+  GlobalScratchAllocOpConversion(LLVMTypeConverter &converter,
+                                 PatternBenefit benefit)
+      : ConvertOpToLLVMPattern(converter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto opOffsetAttr = op->getAttrOfType<mlir::IntegerAttr>(
+        "ttg.global_scratch_memory_offset");
+    assert(opOffsetAttr);
+    auto opOffset = opOffsetAttr.getValue().getZExtValue();
+
+    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+    if (!funcOp) {
+      return failure();
+    }
+    Value ptr =
+        LLVM::getGlobalScratchPtr(loc, rewriter, funcOp, i32_val(opOffset));
+
+    rewriter.replaceOp(op, ptr);
+    return success();
+  }
+};
+
 struct LocalAllocOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::LocalAllocOp> {
   LocalAllocOpConversion(const LLVMTypeConverter &converter,
@@ -205,18 +230,16 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
+    auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
     assert(!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) &&
            "Unexpected rank of ConvertLayout(shared->blocked)");
-    auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    auto dstLayout = dstTy.getEncoding();
-    auto inOrd = getOrder(srcSharedLayout);
 
-    auto smemObj = getSharedMemoryObjectFromStruct(
+    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
         typeConverter->convertType(srcTy.getElementType()), rewriter);
     auto elemLlvmTy = typeConverter->convertType(dstTy.getElementType());
 
-    SmallVector<Value> outVals = mlir::triton::intel::loadSharedToDistributed(
+    SmallVector<Value> outVals = loadSharedToDistributed(
         dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
 
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
@@ -277,6 +300,7 @@ void mlir::triton::intel::populateMemoryOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit,
     std::optional<BackendCallbacks> backendCallbacks) {
+  patterns.add<GlobalScratchAllocOpConversion>(typeConverter, benefit);
   patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
   patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
   patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 0 additions & 52 deletions
@@ -768,58 +768,6 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
   return ret;
 }
 
-inline SmallVector<Value>
-loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy,
-                        Type elemLlvmTy, const SharedMemoryObject &smemObj,
-                        Location loc, RewriterBase &rewriter,
-                        const TargetInfoBase &target) {
-  SmallVector<Value> ret;
-  bool success = emitTransferBetweenRegistersAndShared(
-      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
-      rewriter, target, [&](VectorType vecTy, Value vecAddr) {
-        auto vecVal = load(vecTy, vecAddr);
-        vecVal.setAlignment(vecTy.getNumElements() *
-                            elemLlvmTy.getIntOrFloatBitWidth() / 8);
-
-        for (int v = 0; v < vecTy.getNumElements(); v++) {
-          ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v)));
-        }
-      });
-  if (!success)
-    llvm::report_fatal_error("Failed to emit transfer from shared to register");
-
-  return ret;
-}
-
-inline void
-storeDistributedToShared(triton::gpu::MemDescType dstTy, RankedTensorType srcTy,
-                         Type elemLlvmTy, ArrayRef<Value> srcVals,
-                         const SharedMemoryObject &smemObj, Location loc,
-                         RewriterBase &rewriter, const TargetInfoBase &target,
-                         std::pair<size_t, Type> *const llvmOpCount) {
-  bool success = emitTransferBetweenRegistersAndShared(
-      srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
-      rewriter, target, [&](VectorType vecTy, Value vecAddr) {
-        ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
-        srcVals = srcVals.drop_front(vecTy.getNumElements());
-
-        Value vec = undef(vecTy);
-        for (int i = 0; i < vals.size(); i++) {
-          vec = insert_element(vec, vals[i], i32_val(i));
-        }
-        store(vec, vecAddr)
-            .setAlignment(vecTy.getNumElements() *
-                          elemLlvmTy.getIntOrFloatBitWidth() / 8);
-        if (llvmOpCount) {
-          ++(llvmOpCount->first);
-          llvmOpCount->second = vecTy;
-        }
-      });
-
-  if (!success)
-    llvm::report_fatal_error("Failed to emit transfer from register to shared");
-}
-
 Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
                         Value v);
 Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
