Skip to content

Commit fdab3bb

Browse files
authored
Fix test_gather (#3010)
Make `getStackPointer` part of the `TargetInfo` interface to generalize `getSharedMemoryBase` in the gather op.
1 parent c83c0ed commit fdab3bb

File tree

17 files changed

+67
-65
lines changed

17 files changed

+67
-65
lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,9 @@ class TargetInfoBase {
9191

9292
virtual bool supportVectorizedAtomics() const = 0;
9393

94+
virtual Value getStackPointer(RewriterBase &rewriter,
95+
FunctionOpInterface funcOp) const = 0;
96+
9497
virtual ~TargetInfoBase() {}
9598
};
9699
} // namespace mlir::triton

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -381,19 +381,6 @@ inline bool isKernel(FunctionOpInterface funcOp) {
381381
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
382382
}
383383

384-
inline Value getStackPointer(RewriterBase &rewriter,
385-
FunctionOpInterface funcOp) {
386-
// See NOTE: [Additional Function Arguments]
387-
if (!isKernel(funcOp)) {
388-
return funcOp.getArgument(funcOp.getNumArguments() - 2);
389-
}
390-
391-
auto mod = funcOp->getParentOfType<ModuleOp>();
392-
auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
393-
assert(globalBase);
394-
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
395-
}
396-
397384
inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
398385
FunctionOpInterface funcOp,
399386
Value allocOffset = {}) {
@@ -457,7 +444,8 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
457444
.getValue()
458445
.getZExtValue();
459446
Value offVal = i32_val(offset);
460-
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
447+
Value base =
448+
gep(ptrTy, i8_ty, target.getStackPointer(rewriter, func), offVal);
461449
return base;
462450
}
463451

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
8383
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
8484
adaptor.getOperands(), rewriter);
8585
if (!caller->hasAttr("allocation.offset")) {
86-
auto base = LLVM::getStackPointer(rewriter, caller);
86+
auto base = targetInfo.getStackPointer(rewriter, caller);
8787
promotedOperands.push_back(base);
8888
} else {
8989
auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp);

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6304,8 +6304,6 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
63046304
([128, 64], [128, 128], 1),
63056305
])
63066306
def test_gather(src_shape, indices_shape, axis, device):
6307-
if is_xpu():
6308-
pytest.skip("Fail on XPU")
63096307

63106308
def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
63116309
output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -426,6 +426,19 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,
426426

427427
int TargetInfo::getSharedAddressSpace() const { return 3; }
428428

429+
Value TargetInfo::getStackPointer(RewriterBase &rewriter,
430+
FunctionOpInterface funcOp) const {
431+
// See NOTE: [Additional Function Arguments]
432+
if (!LLVM::isKernel(funcOp)) {
433+
return funcOp.getArgument(funcOp.getNumArguments() - 2);
434+
}
435+
436+
auto mod = funcOp->getParentOfType<ModuleOp>();
437+
auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
438+
assert(globalBase);
439+
return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
440+
}
441+
429442
bool TargetInfo::supportVectorizedAtomics() const {
430443
// Note: not currently tested or used, but AMD generally supports vectorized
431444
// atomics.

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,9 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
6565

6666
bool supportVectorizedAtomics() const override;
6767

68+
Value getStackPointer(RewriterBase &rewriter,
69+
FunctionOpInterface funcOp) const override;
70+
6871
private:
6972
void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args,
7073
RewriterBase &rewriter, bool useStdErr) const;

third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,11 +86,11 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
8686
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
8787
adaptor.getOperands(), rewriter);
8888
if (!caller->hasAttr("allocation.offset")) {
89-
auto base = LLVM::intel::getStackPointer(rewriter, caller);
89+
auto base = targetInfo.getStackPointer(rewriter, caller);
9090
promotedOperands.push_back(base);
9191
return promotedOperands;
9292
}
93-
promotedOperands.push_back(LLVM::intel::getSharedMemoryBase(
93+
promotedOperands.push_back(LLVM::getSharedMemoryBase(
9494
callOp->getLoc(), rewriter, targetInfo, callOp));
9595
return promotedOperands;
9696
}

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -237,8 +237,8 @@ struct ConvertLayoutOpConversion
237237
Attribute srcLayout = srcTy.getEncoding();
238238
Attribute dstLayout = dstTy.getEncoding();
239239

240-
Value smemBase = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo,
241-
op.getOperation());
240+
Value smemBase =
241+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
242242
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
243243
smemBase = bitcast(smemBase, elemPtrTy);
244244
auto shape = dstTy.getShape();
@@ -819,8 +819,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
819819
Type elementType = inVals.front().getType();
820820
auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();
821821

822-
Value smemBase = LLVM::intel::getSharedMemoryBase(
823-
loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
822+
Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
823+
&*rewriter.getInsertionPoint());
824824
Type ptrType = smemBase.getType();
825825

826826
int numRows = inVals.size();

third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -181,8 +181,8 @@ struct HistogramOpConversion
181181
// TODO: we could skip this for cases with num_warps=1 as long as we can
182182
// generate the right layout. Currently the warp level histogram generates
183183
// data in the default blocked layout.
184-
Value baseSharedMemPtr = LLVM::intel::getSharedMemoryBase(
185-
loc, rewriter, targetInfo, op.getOperation());
184+
Value baseSharedMemPtr =
185+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
186186
auto dstType = op.getType();
187187
Attribute dstEncoding = dstType.getEncoding();
188188
auto indices = ::intel::emitIndices(op.getLoc(), rewriter, targetInfo,

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1500,8 +1500,8 @@ struct AtomicCASOpConversion
15001500
rewriter.eraseOp(op);
15011501
return success();
15021502
}
1503-
Value atomPtr = LLVM::intel::getSharedMemoryBase(
1504-
loc, rewriter, targetInfo, op.getOperation());
1503+
Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
1504+
op.getOperation());
15051505
atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
15061506
targetInfo.storeShared(rewriter, loc, atomPtr, ret, mask);
15071507
createBarrier(rewriter, loc, numCTAs);
@@ -1681,8 +1681,8 @@ struct AtomicRMWOpConversion
16811681
rewriter.eraseOp(op);
16821682
return success();
16831683
}
1684-
Value atomPtr = LLVM::intel::getSharedMemoryBase(
1685-
loc, rewriter, targetInfo, op.getOperation());
1684+
Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
1685+
op.getOperation());
16861686
atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3));
16871687
// Only threads with rmwMask = True store the result
16881688
targetInfo.storeShared(rewriter, loc, atomPtr, ret, rmwMask);

0 commit comments

Comments
 (0)