@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
334334// /
335335// / Result:
336336// / linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
337- static void atomicRMWStore (OpBuilder &builder, Location loc,
338- MemRefValue linearizedMemref, Value storeIdx,
339- VectorValue valueToStore, Value mask) {
337+ static void atomicRMW (OpBuilder &builder, Location loc,
338+ MemRefValue linearizedMemref, Value storeIdx,
339+ VectorValue valueToStore, Value mask) {
340340 assert (valueToStore.getType ().getRank () == 1 && " expected 1-D vector" );
341341
342342 // Create an atomic load-modify-write region using
@@ -363,12 +363,11 @@ static void atomicRMWStore(OpBuilder &builder, Location loc,
363363 builder.create <memref::AtomicYieldOp>(loc, scalarMaskedValue);
364364}
365365
366- // / Generate a non-atomic read-modify-write sequence for subbyte storing.
367- // / It has similar logic to `atomicRMWStore`, but without atomicity.
368- static void nonAtomicRMWStore (OpBuilder &builder, Location loc,
369- MemRefValue linearizedMemref,
370- Value linearizedIndex, VectorValue valueToStore,
371- Value mask) {
366+ // / Generate a non-atomic read-modify-write sequence for storing to the emulated
367+ // / type. It has similar logic to `atomicRMW`, but without atomicity.
368+ static void nonAtomicRMW (OpBuilder &builder, Location loc,
369+ MemRefValue linearizedMemref, Value linearizedIndex,
370+ VectorValue valueToStore, Value mask) {
372371 assert (valueToStore.getType ().getRank () == 1 && " expected 1-D vector" );
373372
374373 auto oneElemVecType =
@@ -427,9 +426,9 @@ namespace {
427426struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
428427 using OpConversionPattern::OpConversionPattern;
429428
430- ConvertVectorStore (MLIRContext *context, bool useAtomicWrites )
429+ ConvertVectorStore (MLIRContext *context, bool disableAtomicRMW )
431430 : OpConversionPattern<vector::StoreOp>(context),
432- useAtomicWrites_ (useAtomicWrites ) {}
431+ disableAtomicRMW (disableAtomicRMW ) {}
433432
434433 LogicalResult
435434 matchAndRewrite (vector::StoreOp op, OpAdaptor adaptor,
@@ -557,6 +556,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
557556 auto subWidthStoreMaskType =
558557 VectorType::get ({numSrcElemsPerDest}, rewriter.getI1Type ());
559558
559+ auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
560+
560561 // 1. Partial width store for the leading byte.
561562 // When the store address is not aligned to emulated width boundary, deal
562563 // with the unaligned part so that the rest elements are aligned to width
@@ -581,8 +582,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
581582 extractSliceIntoByte (rewriter, loc, valueToStore, 0 ,
582583 frontSubWidthStoreElem, *foldedNumFrontPadElems);
583584
584- auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;
585-
586585 storeFunc (rewriter, loc, memrefBase, currentDestIndex,
587586 cast<VectorValue>(value), frontMask.getResult ());
588587 }
@@ -639,17 +638,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
639638 auto backMask = rewriter.create <arith::ConstantOp>(
640639 loc, DenseElementsAttr::get (subWidthStoreMaskType, maskValues));
641640
642- subEmulatedWidthStore (rewriter, loc, memrefBase, currentDestIndex,
643- cast<VectorValue>(subWidthStorePart),
644- backMask.getResult ());
641+ storeFunc (rewriter, loc, memrefBase, currentDestIndex,
642+ cast<VectorValue>(subWidthStorePart), backMask.getResult ());
645643 }
646644
647645 rewriter.eraseOp (op);
648646 return success ();
649647 }
650648
651649private:
652- const bool useAtomicWrites_ ;
650+ const bool disableAtomicRMW ;
653651};
654652
655653// ===----------------------------------------------------------------------===//
@@ -1962,7 +1960,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
19621960
19631961void vector::populateVectorNarrowTypeEmulationPatterns (
19641962 const arith::NarrowTypeEmulationConverter &typeConverter,
1965- RewritePatternSet &patterns, bool useAtomicWrites ) {
1963+ RewritePatternSet &patterns, bool disableAtomicRMW ) {
19661964
19671965 // Populate `vector.*` conversion patterns.
19681966 // TODO: #119553 support atomicity
@@ -1973,7 +1971,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
19731971 // Populate `vector.*` store conversion patterns. The caller can choose
19741972 // to avoid emitting atomic operations and reduce it to load-modify-write
19751973 // sequence for stores if it is known there are no thread contentions.
1976- patterns.insert <ConvertVectorStore>(patterns.getContext (), useAtomicWrites );
1974+ patterns.insert <ConvertVectorStore>(patterns.getContext (), disableAtomicRMW );
19771975}
19781976
19791977void vector::populateVectorNarrowTypeRewritePatterns (
0 commit comments