Skip to content

Commit 2d22907

Browse files
committed
Add vectorization support for store as well
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0b21a82 commit 2d22907

File tree

2 files changed

+22
-39
lines changed

2 files changed

+22
-39
lines changed

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,6 @@ class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl<triton::LoadOp> {
558558
// If pointers and mask both have constancy properties, those properties
559559
// will also extend to output.
560560
AxisInfo ptrInfo = operands[0]->getValue();
561-
562561
std::optional<AxisInfo> maskInfo;
563562
if (operands.size() > 1) {
564563
maskInfo = operands[1]->getValue();

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -827,46 +827,34 @@ struct LoadOpConversion
827827
rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded())
828828
return success();
829829

830-
auto loc = op->getLoc();
831-
auto typeConverter = getTypeConverter();
832-
auto *ctx = rewriter.getContext();
830+
Location loc = op->getLoc();
831+
TritonIntelGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
832+
MLIRContext *ctx = rewriter.getContext();
833+
Value ptr = op.getPtr();
834+
Value mask = op.getMask();
835+
Value llMask = adaptor.getMask();
833836

834837
// Determine the vectorization size
835838
Type valueElemTy =
836839
typeConverter->convertType(getElementTypeOrSelf(op.getType()));
837840
unsigned numElems = getTotalElemsPerThread(op.getType());
838-
unsigned vec = 1;
841+
unsigned vec = getVectorSize(ptr);
842+
if (llMask)
843+
vec = std::min<size_t>(vec, getMaskAlignment(mask));
839844

840845
SmallVector<Value> ptrElems, maskElems, otherElems;
841846
bool otherIsSplatConstInt = false;
842847
int64_t splatVal = 0;
843848

844-
if (isTensorPointerType(op.getPtr().getType())) {
845-
Value ptr = op.getPtr();
846-
Value mask = op.getMask();
847-
Value llMask = adaptor.getMask();
848-
vec = getVectorSize(ptr);
849-
if (llMask)
850-
vec = std::min<size_t>(vec, getMaskAlignment(mask));
851-
852-
Type resultType = op.getType();
853-
auto tensorType = cast<RankedTensorType>(resultType);
849+
if (isTensorPointerType(ptr.getType())) {
850+
auto tensorType = cast<RankedTensorType>(op.getType());
854851
std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
855852
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
856853
op.getBoundaryCheck(), op.getPadding());
857854
} else {
858-
// original values
859-
Value ptr = op.getPtr();
860855
Value other = op.getOther();
861-
Value mask = op.getMask();
862-
863-
// adaptor values
864856
Value llPtr = adaptor.getPtr();
865-
Value llMask = adaptor.getMask();
866857
Value llOther = adaptor.getOther();
867-
vec = getVectorSize(ptr);
868-
if (llMask)
869-
vec = std::min<size_t>(vec, getMaskAlignment(mask));
870858

871859
// Get the LLVM values for pointers
872860
ptrElems = unpackLLElements(loc, llPtr, rewriter);
@@ -1141,39 +1129,35 @@ struct StoreOpConversion
11411129
return success();
11421130

11431131
Location loc = op->getLoc();
1132+
TritonIntelGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
11441133
MLIRContext *ctx = rewriter.getContext();
1145-
Value value = op.getValue();
1146-
11471134
Value ptr = op.getPtr();
1135+
Value mask = op.getMask();
1136+
Value llMask = adaptor.getMask();
1137+
1138+
// Determine the vectorization size
1139+
Value value = op.getValue();
11481140
Type valueTy = value.getType();
11491141
Type valueElemTy =
11501142
typeConverter->convertType(getElementTypeOrSelf(valueTy));
1151-
SmallVector<Value> ptrElems;
1152-
SmallVector<Value> maskElems;
1153-
unsigned vec = 1;
1143+
SmallVector<Value> ptrElems, maskElems;
1144+
unsigned vec = getVectorSize(ptr);
1145+
if (llMask)
1146+
vec = std::min<size_t>(vec, getMaskAlignment(mask));
11541147

11551148
if (isTensorPointerType(ptr.getType())) {
1156-
// fallback to scatter store.
11571149
auto tensorType = cast<RankedTensorType>(valueTy);
11581150
SmallVector<Value> dummyOther;
11591151
std::tie(ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr(
11601152
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
11611153
op.getBoundaryCheck());
11621154
} else {
11631155
Value llPtr = adaptor.getPtr();
1164-
Value llMask = adaptor.getMask();
1165-
1166-
vec = getVectorSize(ptr);
11671156

11681157
ptrElems = unpackLLElements(loc, llPtr, rewriter);
11691158

1170-
// Determine the vectorization size
11711159
if (llMask) {
1172-
Value mask = op.getMask();
11731160
maskElems = unpackLLElements(loc, llMask, rewriter);
1174-
1175-
unsigned maskAlign = getMaskAlignment(mask);
1176-
vec = std::min(vec, maskAlign);
11771161
}
11781162
}
11791163

@@ -1183,7 +1167,7 @@ struct StoreOpConversion
11831167
assert(!maskElems.size() ||
11841168
valueElems.size() == maskElems.size() && "Mask size mismatch");
11851169

1186-
Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
1170+
mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
11871171
const size_t dtsize =
11881172
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
11891173
const size_t valueElemNBits = dtsize * 8;

0 commit comments

Comments
 (0)