@@ -33,14 +33,39 @@ void lowerDistributedToShared(
3333 auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding ()).getOrder ();
3434 auto elemTy = typeConverter->convertType (srcTy.getElementType ());
3535
36- auto smemBase = smemObj.getBase ();
37- auto dstStrides = smemObj.getStrides ();
3836 auto inVals = unpackLLElements (loc, adaptorSrc, rewriter);
39- mlir::triton::intel::storeDistributedToShared (dstTy, srcTy, elemTy, inVals,
40- smemObj, loc, rewriter,
41- targetInfo, llvmOpCount);
37+ storeDistributedToShared (dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
38+ targetInfo, llvmOpCount);
4239}
4340
41+ struct GlobalScratchAllocOpConversion
42+ : public ConvertOpToLLVMPattern<triton::gpu::GlobalScratchAllocOp> {
43+ GlobalScratchAllocOpConversion (LLVMTypeConverter &converter,
44+ PatternBenefit benefit)
45+ : ConvertOpToLLVMPattern(converter, benefit) {}
46+
47+ LogicalResult
48+ matchAndRewrite (triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor,
49+ ConversionPatternRewriter &rewriter) const override {
50+ Location loc = op.getLoc ();
51+
52+ auto opOffsetAttr = op->getAttrOfType <mlir::IntegerAttr>(
53+ " ttg.global_scratch_memory_offset" );
54+ assert (opOffsetAttr);
55+ auto opOffset = opOffsetAttr.getValue ().getZExtValue ();
56+
57+ auto funcOp = op->getParentOfType <LLVM::LLVMFuncOp>();
58+ if (!funcOp) {
59+ return failure ();
60+ }
61+ Value ptr =
62+ LLVM::getGlobalScratchPtr (loc, rewriter, funcOp, i32_val (opOffset));
63+
64+ rewriter.replaceOp (op, ptr);
65+ return success ();
66+ }
67+ };
68+
4469struct LocalAllocOpConversion
4570 : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::LocalAllocOp> {
4671 LocalAllocOpConversion (const LLVMTypeConverter &converter,
@@ -205,18 +230,16 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
205230 auto srcTy = op.getSrc ().getType ();
206231 auto dstTy = op.getResult ().getType ();
207232 auto dstShape = dstTy.getShape ();
233+ auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding ());
208234 assert (!isa<DotOperandEncodingAttr>(dstTy.getEncoding ()) &&
209235 " Unexpected rank of ConvertLayout(shared->blocked)" );
210- auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding ());
211- auto dstLayout = dstTy.getEncoding ();
212- auto inOrd = getOrder (srcSharedLayout);
213236
214- auto smemObj = getSharedMemoryObjectFromStruct (
237+ auto smemObj = LLVM:: getSharedMemoryObjectFromStruct (
215238 loc, adaptor.getSrc (),
216239 typeConverter->convertType (srcTy.getElementType ()), rewriter);
217240 auto elemLlvmTy = typeConverter->convertType (dstTy.getElementType ());
218241
219- SmallVector<Value> outVals = mlir::triton::intel:: loadSharedToDistributed (
242+ SmallVector<Value> outVals = loadSharedToDistributed (
220243 dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
221244
222245 Value result = packLLElements (loc, typeConverter, outVals, rewriter, dstTy);
@@ -277,6 +300,7 @@ void mlir::triton::intel::populateMemoryOpToLLVMPattern(
277300 LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
278301 RewritePatternSet &patterns, PatternBenefit benefit,
279302 std::optional<BackendCallbacks> backendCallbacks) {
303+ patterns.add <GlobalScratchAllocOpConversion>(typeConverter, benefit);
280304 patterns.add <LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
281305 patterns.add <LocalDeallocOpConversion>(typeConverter, benefit);
282306 patterns.add <LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
0 commit comments