@@ -363,6 +363,29 @@ static void atomicStore(OpBuilder &builder, Location loc,
363363 builder.create <memref::AtomicYieldOp>(loc, scalarMaskedValue);
364364}
365365
366+ // / Generate a non-atomic read-modify-write sequence for subbyte storing.
367+ // / It has similar logic to `atomicStore`, but without the atomicity.
368+ static void rmwStore (OpBuilder &builder, Location loc,
369+ MemRefValue linearizedMemref, Value linearizedIndex,
370+ VectorValue valueToStore, Value mask) {
371+ assert (valueToStore.getType ().getRank () == 1 && " expected 1-D vector" );
372+
373+ // Load the original value from memory, and cast it to the original element
374+ // type.
375+ auto oneElemVecType =
376+ VectorType::get ({1 }, linearizedMemref.getType ().getElementType ());
377+ Value origVecValue = builder.create <vector::LoadOp>(
378+ loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
379+ origVecValue = builder.create <vector::BitCastOp>(loc, valueToStore.getType (),
380+ origVecValue);
381+
382+ // Construct the final masked value and yield it.
383+ Value maskedValue = selectAndCast (builder, loc, oneElemVecType, mask,
384+ origVecValue, valueToStore);
385+ builder.create <vector::StoreOp>(loc, maskedValue, linearizedMemref,
386+ linearizedIndex);
387+ }
388+
/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
/// and insert it into an empty vector at `insertOffset`.
/// Inputs:
@@ -405,6 +428,10 @@ namespace {
405428struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
406429 using OpConversionPattern::OpConversionPattern;
407430
  /// Construct the pattern; `useAtomicWrites` selects between atomic RMW
  /// emulation and a plain load-modify-write sequence for subbyte stores.
  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
      : OpConversionPattern<vector::StoreOp>(context),
        useAtomicWrites_(useAtomicWrites) {}
434+
408435 LogicalResult
409436 matchAndRewrite (vector::StoreOp op, OpAdaptor adaptor,
410437 ConversionPatternRewriter &rewriter) const override {
@@ -611,13 +638,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
611638 auto backMask = rewriter.create <arith::ConstantOp>(
612639 loc, DenseElementsAttr::get (subWidthStoreMaskType, maskValues));
613640
614- atomicStore (rewriter, loc, memrefBase, currentDestIndex,
615- cast<VectorValue>(subWidthStorePart), backMask.getResult ());
641+ subEmulatedWidthStore (rewriter, loc, memrefBase, currentDestIndex,
642+ cast<VectorValue>(subWidthStorePart),
643+ backMask.getResult ());
616644 }
617645
618646 rewriter.eraseOp (op);
619647 return success ();
620648 }
649+
650+ // / Store a subbyte-sized value to memory, with a mask. Depending on the
651+ // / configuration, it could be an atomic store or a non-atomic RMW sequence.
652+ template <typename ... Args>
653+ void subEmulatedWidthStore (Args &&...args) const {
654+ static_assert (
655+ std::is_same_v<decltype (atomicStore), decltype (rmwStore)> &&
656+ " `atomicStore` and `rmwStore` must have same signature, as per "
657+ " the design to keep the code clean, which one to call is "
658+ " determined by the `useAtomicWrites` flag." );
659+ std::function<decltype (atomicStore)> storeFunc =
660+ useAtomicWrites_ ? atomicStore : rmwStore;
661+ storeFunc (std::forward<Args>(args)...);
662+ }
663+
664+ private:
665+ const bool useAtomicWrites_;
621666};
622667
623668// ===----------------------------------------------------------------------===//
@@ -1930,12 +1975,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
19301975
19311976void vector::populateVectorNarrowTypeEmulationPatterns (
19321977 const arith::NarrowTypeEmulationConverter &typeConverter,
1933- RewritePatternSet &patterns) {
1978+ RewritePatternSet &patterns, bool useAtomicWrites ) {
19341979
19351980 // Populate `vector.*` conversion patterns.
1936- patterns.add <ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
1981+ // TODO: #119553 support atomicity
1982+ patterns.add <ConvertVectorLoad, ConvertVectorMaskedLoad,
19371983 ConvertVectorMaskedStore, ConvertVectorTransferRead>(
19381984 typeConverter, patterns.getContext ());
1985+
1986+ // Populate `vector.*` store conversion patterns. The caller can choose
1987+ // to avoid emitting atomic operations and reduce it to load-modify-write
1988+ // sequence for stores if it is known there are no thread contentions.
1989+ patterns.insert <ConvertVectorStore>(patterns.getContext (), useAtomicWrites);
19391990}
19401991
19411992void vector::populateVectorNarrowTypeRewritePatterns (
0 commit comments