Skip to content

Commit 1d456fd

Browse files
committed
Disable vectorization for atomic_cas on xpu backend
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 45df313 commit 1d456fd

File tree

1 file changed

+5
-24
lines changed

1 file changed

+5
-24
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3154,36 +3154,21 @@ struct AtomicCASOpConversion
31543154
: valueTy;
31553155
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
31563156
auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType());
3157-
// vec = 1 for scalar
3158-
auto vec = getVectorSize(op.getPtr());
3159-
// tensor
3160-
if (tensorTy) {
3161-
auto valTy = cast<RankedTensorType>(op.getVal().getType());
3162-
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
3163-
}
31643157

31653158
auto freeVarMasks = getFreeVariableMasks(valueTy);
31663159
Value mask =
31673160
emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
3168-
auto vecTy = vec_ty(valueElemTy, vec);
31693161
SmallVector<Value> resultVals(elemsPerThread);
31703162

31713163
MemSemantic memSem = op.getSem();
31723164
LLVM::AtomicOrdering successOrdering = getMemoryOrdering(memSem)
31733165
? *getMemoryOrdering(memSem)
31743166
: LLVM::AtomicOrdering::acq_rel;
31753167
LLVM::AtomicOrdering failureOrdering = LLVM::AtomicOrdering::monotonic;
3176-
for (size_t i = 0; i < elemsPerThread; i += vec) {
3177-
Value casVal = b.undef(vecTy);
3178-
for (int ii = 0; ii < vec; ++ii) {
3179-
Value iiVal = createIndexAttrConstant(
3180-
rewriter, loc, getTypeConverter()->getIndexType(), ii);
3181-
casVal = b.insert_element(vecTy, casVal, valElements[i + ii], iiVal);
3182-
}
3183-
3168+
for (size_t i = 0; i < elemsPerThread; ++i) {
31843169
Value casPtr = ptrElements[i];
31853170
Value casCmp = cmpElements[i];
3186-
casVal = valElements[i];
3171+
Value casVal = valElements[i];
31873172

31883173
assert((valueElemNBits == 32 || valueElemNBits == 64) &&
31893174
"Unexpected width");
@@ -3212,15 +3197,11 @@ struct AtomicCASOpConversion
32123197
} else {
32133198
ret = createAtomicCASInstruction()[0];
32143199
}
3215-
Type retType = (!tensorTy || vec == 1) ? valueElemTy : vecTy;
3216-
ret = b.bitcast(ret, retType);
3200+
3201+
ret = b.bitcast(ret, valueElemTy);
32173202

32183203
if (tensorTy) {
3219-
for (int ii = 0; ii < vec; ++ii) {
3220-
resultVals[i + ii] =
3221-
vec == 1 ? ret
3222-
: b.extract_element(valueElemTy, ret, b.i32_val(ii));
3223-
}
3204+
resultVals[i] = ret;
32243205
} else {
32253206
if (op.getResult().use_empty()) {
32263207
rewriter.eraseOp(op);

0 commit comments

Comments
 (0)