@@ -3154,36 +3154,21 @@ struct AtomicCASOpConversion
31543154 : valueTy;
31553155 auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth ();
31563156 auto elemsPerThread = getTotalElemsPerThread (op.getVal ().getType ());
3157- // vec = 1 for scalar
3158- auto vec = getVectorSize (op.getPtr ());
3159- // tensor
3160- if (tensorTy) {
3161- auto valTy = cast<RankedTensorType>(op.getVal ().getType ());
3162- vec = std::min<unsigned >(vec, valTy.getElementType ().isF16 () ? 2 : 1 );
3163- }
31643157
31653158 auto freeVarMasks = getFreeVariableMasks (valueTy);
31663159 Value mask =
31673160 emitRedundantThreadPredicate (freeVarMasks, rewriter, loc, targetInfo);
3168- auto vecTy = vec_ty (valueElemTy, vec);
31693161 SmallVector<Value> resultVals (elemsPerThread);
31703162
31713163 MemSemantic memSem = op.getSem ();
31723164 LLVM::AtomicOrdering successOrdering = getMemoryOrdering (memSem)
31733165 ? *getMemoryOrdering (memSem)
31743166 : LLVM::AtomicOrdering::acq_rel;
31753167 LLVM::AtomicOrdering failureOrdering = LLVM::AtomicOrdering::monotonic;
3176- for (size_t i = 0 ; i < elemsPerThread; i += vec) {
3177- Value casVal = b.undef (vecTy);
3178- for (int ii = 0 ; ii < vec; ++ii) {
3179- Value iiVal = createIndexAttrConstant (
3180- rewriter, loc, getTypeConverter ()->getIndexType (), ii);
3181- casVal = b.insert_element (vecTy, casVal, valElements[i + ii], iiVal);
3182- }
3183-
3168+ for (size_t i = 0 ; i < elemsPerThread; ++i) {
31843169 Value casPtr = ptrElements[i];
31853170 Value casCmp = cmpElements[i];
3186- casVal = valElements[i];
3171+ Value casVal = valElements[i];
31873172
31883173 assert ((valueElemNBits == 32 || valueElemNBits == 64 ) &&
31893174 " Unexpected width" );
@@ -3212,15 +3197,11 @@ struct AtomicCASOpConversion
32123197 } else {
32133198 ret = createAtomicCASInstruction ()[0 ];
32143199 }
3215- Type retType = (!tensorTy || vec == 1 ) ? valueElemTy : vecTy;
3216- ret = b.bitcast (ret, retType );
3200+
3201+ ret = b.bitcast (ret, valueElemTy );
32173202
32183203 if (tensorTy) {
3219- for (int ii = 0 ; ii < vec; ++ii) {
3220- resultVals[i + ii] =
3221- vec == 1 ? ret
3222- : b.extract_element (valueElemTy, ret, b.i32_val (ii));
3223- }
3204+ resultVals[i] = ret;
32243205 } else {
32253206 if (op.getResult ().use_empty ()) {
32263207 rewriter.eraseOp (op);
0 commit comments