@@ -73,6 +73,7 @@ namespace {
 #define GEN_PASS_DEF_LOWERTENSORSPASS
 #include "xla/backends/gpu/codegen/transforms/passes.h.inc"
 
+using llvm::dyn_cast_or_null;
 using mlir::failure;
 using mlir::Location;
 using mlir::LogicalResult;
@@ -93,6 +94,7 @@ using mlir::ValueRange;
 namespace arith = ::mlir::arith;
 namespace scf = ::mlir::scf;
 namespace ml = ::mlir::LLVM;
+namespace vector = ::mlir::vector;
 
 bool IsAMD(const se::DeviceDescription& device_description) {
   return std::holds_alternative<se::RocmComputeCapability>(
@@ -114,7 +116,7 @@ Value GetDestinationBuffer(Value dest) {
                dest.getDefiningOp<AllocateSharedOp>()) {
       break;
     } else if (auto transfer_write =
-                   dest.getDefiningOp<mlir::vector::TransferWriteOp>()) {
+                   dest.getDefiningOp<vector::TransferWriteOp>()) {
       dest = transfer_write.getSource();
     } else {
       dest.getDefiningOp()->emitOpError("unsupported dest type");
@@ -168,7 +170,7 @@ struct RewriteFunctionSignatures : OpRewritePattern<mlir::func::FuncOp> {
       auto cast = rewriter.create<UnrealizedConversionCastOp>(
           op.getLoc(), operand, op.getArgument(index));
       op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast);
-      operand = mlir::LLVM::LLVMPointerType::get(op.getContext());
+      operand = ml::LLVMPointerType::get(op.getContext());
     }
   }
 
@@ -188,7 +190,7 @@ Value GetPtr(Value value) {
   }
   if (auto cast = value.getDefiningOp<UnrealizedConversionCastOp>()) {
     if (cast.getNumOperands() == 1 && cast.getNumResults() == 1 &&
-        mlir::isa<mlir::LLVM::LLVMPointerType>(cast.getOperand(0).getType())) {
+        mlir::isa<ml::LLVMPointerType>(cast.getOperand(0).getType())) {
       return cast.getOperand(0);
     }
   }
@@ -294,25 +296,25 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
   return {i8_index, is_low_nibble};
 }
 
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
-                            Value linear_index, mlir::ImplicitLocOpBuilder& b) {
+ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+                    Value linear_index, mlir::ImplicitLocOpBuilder& b) {
   Type element_type = tensor.getType().getElementType();
   if (element_type == b.getI4Type()) {
     element_type = b.getI8Type();
   }
-  auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
+  auto ptr = ml::LLVMPointerType::get(b.getContext());
   auto tensor_ptr =
       b.create<UnrealizedConversionCastOp>(ptr, tensor).getResult(0);
   mlir::LLVMTypeConverter converter(b.getContext());
   auto llvm_element_type = converter.convertType(element_type);
-  auto gep = b.create<mlir::LLVM::GEPOp>(ptr, llvm_element_type, tensor_ptr,
-                                         linear_index);
+  auto gep =
+      b.create<ml::GEPOp>(ptr, llvm_element_type, tensor_ptr, linear_index);
   gep.setInbounds(true);
   return gep;
 }
 
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
-                            ValueRange indices, mlir::ImplicitLocOpBuilder& b) {
+ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+                    ValueRange indices, mlir::ImplicitLocOpBuilder& b) {
   return CreateGep(tensor, GetLinearIndex(indices, b), b);
 }
 
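A note on the helpers this hunk renames: `GetI4IndexAndNibble` splits a linear i4 index into a byte index plus a nibble selector, and `CreateGep` then emits plain base-plus-index address arithmetic as an inbounds GEP. Below is a host-side sketch of the same arithmetic, with hypothetical names and an assumed nibble-parity convention (the actual parity is the emitter's choice):

```cpp
#include <cstdint>
#include <utility>

// Hypothetical model of GetI4IndexAndNibble: two 4-bit elements share each
// byte, so the byte index is the linear index halved. Whether even or odd
// indices land in the low nibble is a packing convention; the mapping here
// is an assumption for illustration only.
std::pair<int64_t, bool> I4IndexAndNibble(int64_t linear_index) {
  return {linear_index / 2, (linear_index % 2) == 0};
}

// The address CreateGep computes is base + index * sizeof(element), emitted
// as an inbounds GEP; i4 is widened to i8 because LLVM addresses whole bytes.
int8_t* ElementAddress(int8_t* base, int64_t index) { return base + index; }
```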
@@ -333,8 +335,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
 
     auto gep = CreateGep(op.getTensor(), linear_index, b);
     auto load =
-        rewriter
-            .create<mlir::LLVM::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
+        rewriter.create<ml::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
             .getResult();
 
     if (is_low_nibble) {
@@ -359,19 +360,19 @@ Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) {
   int size = ty.getNumElements();
   Value result = vector;
   for (int i = 0; i < size; i += 2) {
-    auto v0 = b.create<mlir::vector::ExtractOp>(vector, i);
-    auto v1 = b.create<mlir::vector::ExtractOp>(vector, i + 1);
-    result = b.create<mlir::vector::InsertOp>(v1, result, i);
-    result = b.create<mlir::vector::InsertOp>(v0, result, i + 1);
+    auto v0 = b.create<vector::ExtractOp>(vector, i);
+    auto v1 = b.create<vector::ExtractOp>(vector, i + 1);
+    result = b.create<vector::InsertOp>(v1, result, i);
+    result = b.create<vector::InsertOp>(v0, result, i + 1);
   }
   return result;
 }
 
-struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
+struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(
-      mlir::vector::TransferReadOp op,
+      vector::TransferReadOp op,
       mlir::PatternRewriter& rewriter) const override {
     assert(IsSupportedTransfer(op));
 
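For context on `PermutePairsInVector` above: it swaps every adjacent pair of lanes, which the i4 path uses to reorder nibbles after a packed load. A scalar model of the loop:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Scalar model of PermutePairsInVector: swap each adjacent pair, so
// {a, b, c, d} becomes {b, a, d, c}. The pattern does the same with
// vector::ExtractOp/vector::InsertOp and assumes an even lane count,
// matching the loop's stride of two.
std::vector<int> PermutePairs(std::vector<int> v) {
  for (std::size_t i = 0; i + 1 < v.size(); i += 2) {
    std::swap(v[i], v[i + 1]);
  }
  return v;
}
```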
@@ -394,8 +395,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
 
     mlir::LLVMTypeConverter converter(b.getContext());
     auto llvm_vector_type = converter.convertType(vector_type);
-    auto loaded =
-        b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
+    auto loaded = b.create<ml::LoadOp>(llvm_vector_type, gep).getResult();
 
     if (source.getType().getElementType().isInteger(1)) {
       Value zero = b.create<mlir::arith::ConstantOp>(
@@ -484,7 +484,7 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
       scalar_value =
           b.create<UnrealizedConversionCastOp>(llvm_type, scalar_value)
               .getResult(0);
-      b.create<mlir::LLVM::StoreOp>(scalar_value, gep);
+      b.create<ml::StoreOp>(scalar_value, gep);
       op.replaceAllUsesWith(op.getDest());
     }
 
@@ -493,11 +493,11 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
   }
 };
 
-struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
+struct RewriteTransferWrite : OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(
-      mlir::vector::TransferWriteOp op,
+      vector::TransferWriteOp op,
       mlir::PatternRewriter& rewriter) const override {
     assert(IsSupportedTransfer(op));
     Value dest = GetDestinationBuffer(op.getSource());
@@ -526,7 +526,7 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
     auto llvm_type = converter.convertType(vector_value.getType());
     vector_value = b.create<UnrealizedConversionCastOp>(llvm_type, vector_value)
                        .getResult(0);
-    b.create<mlir::LLVM::StoreOp>(vector_value, gep);
+    b.create<ml::StoreOp>(vector_value, gep);
 
     rewriter.replaceOp(op, mlir::ValueRange{op.getSource()});
     return success();
@@ -550,21 +550,19 @@ struct RewriteCall : OpRewritePattern<mlir::func::CallOp> {
             index,
             rewriter
                 .create<UnrealizedConversionCastOp>(
-                    op.getLoc(),
-                    mlir::LLVM::LLVMPointerType::get(op.getContext()), arg)
+                    op.getLoc(), ml::LLVMPointerType::get(op.getContext()), arg)
                 .getResult(0));
       }
     }
     return success();
   }
 };
 
-mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
-                                    const std::string& name_prefix,
-                                    mlir::ShapedType shaped_ty,
-                                    mlir::ModuleOp module, bool is_constant,
-                                    int addr_space,
-                                    mlir::ImplicitLocOpBuilder& b) {
+ml::GlobalOp CreateGlobalOp(mlir::Attribute value,
+                            const std::string& name_prefix,
+                            mlir::ShapedType shaped_ty, mlir::ModuleOp module,
+                            bool is_constant, int addr_space,
+                            mlir::ImplicitLocOpBuilder& b) {
   if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
     // The lowering to LLVM only works for 1d tensors or those with trailing
     // unit dimensions.
@@ -593,19 +591,17 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
           packed_data);
     }
   }
-  auto array_ty =
-      mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements);
+  auto array_ty = ml::LLVMArrayType::get(llvm_element_type, num_elements);
   std::string name;
   int index = 0;
   do {
     name = absl::StrCat(name_prefix, index);
     ++index;
   } while (module.lookupSymbol(name));
   b.setInsertionPointToStart(module.getBody());
-  return b.create<mlir::LLVM::GlobalOp>(
-      array_ty, is_constant,
-      /*linkage=*/mlir::LLVM::Linkage::Private, name, value, /*alignment=*/0,
-      addr_space);
+  return b.create<ml::GlobalOp>(array_ty, is_constant,
+                                /*linkage=*/ml::Linkage::Private, name, value,
+                                /*alignment=*/0, addr_space);
 }
 
 struct RewriteAllocateShared : OpRewritePattern<AllocateSharedOp> {
@@ -623,13 +619,12 @@ struct RewriteAllocateShared : OpRewritePattern<AllocateSharedOp> {
         /*is_constant=*/false, kGPUSharedMemoryAddrSpace, b);
 
     rewriter.setInsertionPoint(op);
-    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+    auto addr = rewriter.create<ml::AddressOfOp>(op.getLoc(), global);
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
         op, op.getResult().getType(),
         rewriter
-            .create<mlir::LLVM::AddrSpaceCastOp>(
-                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
-                addr)
+            .create<ml::AddrSpaceCastOp>(
+                op.getLoc(), ml::LLVMPointerType::get(op.getContext()), addr)
             .getResult());
     return success();
   }
@@ -659,13 +654,12 @@ struct RewriteNonScalarConstants : OpRewritePattern<mlir::arith::ConstantOp> {
         /*is_constant=*/true, kGPUGlobalMemoryAddrSpace, b);
 
     rewriter.setInsertionPoint(op);
-    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+    auto addr = rewriter.create<ml::AddressOfOp>(op.getLoc(), global);
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
         op, op.getResult().getType(),
         rewriter
-            .create<mlir::LLVM::AddrSpaceCastOp>(
-                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
-                addr)
+            .create<ml::AddrSpaceCastOp>(
+                op.getLoc(), ml::LLVMPointerType::get(op.getContext()), addr)
             .getResult());
     return success();
   }
@@ -727,7 +721,7 @@ Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {
   // direct bitcast from a struct to an int is possible.
   Type llvm_input_ty = converter.convertType(value.getType());
   Type llvm_result_ty = converter.convertType(ty);
-  Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext());
+  Type ptr_ty = ml::LLVMPointerType::get(b.getContext());
 
   Value llvm_value =
       b.create<UnrealizedConversionCastOp>(llvm_input_ty, value).getResult(0);
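The `ptr_ty` here feeds the function's workaround for the restriction noted in the comment: when a direct register bitcast between the two types is not legal, the value goes through memory instead. A sketch of that idiom in plain C++, on the assumption that the lowering uses the usual store-then-reload scheme:

```cpp
#include <cstring>

// Bitcast-through-memory: write the value to a stack slot, read it back as
// the other type. std::memcpy is the portable C++ spelling of the store/load
// pair an LLVM-dialect version would emit. Illustration only.
template <typename To, typename From>
To BitcastThroughMemory(const From& from) {
  static_assert(sizeof(To) == sizeof(From), "bitcast requires equal sizes");
  To to;
  std::memcpy(&to, &from, sizeof(To));
  return to;
}
```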
@@ -747,7 +741,19 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
 
   LogicalResult matchAndRewrite(
       AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
-    if (failed(rewriteAsDirectAtomicRMW(op, rewriter))) {
+    auto modifier_parameters = GetAtomicModifierParameters(op);
+    if (modifier_parameters.has_value()) {
+      if (mlir::isa<mlir::VectorType>(modifier_parameters->first.getType()) &&
+          (IsAMD(*device_description_) ||
+           !device_description_->cuda_compute_capability().IsAtLeastHopper())) {
+        return rewriter.notifyMatchFailure(
+            op,
+            "atomic vectorization currently only supported on Hopper or later");
+      }
+    }
+
+    if (!modifier_parameters.has_value() ||
+        failed(rewriteAsDirectAtomicRMW(op, modifier_parameters, rewriter))) {
       rewriteAsAtomicCAS(op, rewriter);
     }
     rewriter.replaceOp(op, op.getInput());
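The reworked dispatch reads: reject vector atomics on unsupported hardware, try the direct hardware atomic, and otherwise fall back to `rewriteAsAtomicCAS`. A host-side model of that CAS fallback for an f32 add (an illustration of the scheme, not the pass's code):

```cpp
#include <atomic>
#include <bit>
#include <cstdint>

// CAS-loop emulation of atomic float add: reinterpret the float as its bit
// pattern and retry compare-exchange until the update lands. On failure,
// compare_exchange_weak refreshes `expected` with the current value, so the
// next iteration recomputes against fresh data. Requires C++20 for bit_cast.
void AtomicAddF32ViaCas(std::atomic<uint32_t>& word, float operand) {
  uint32_t expected = word.load();
  uint32_t desired;
  do {
    desired =
        std::bit_cast<uint32_t>(std::bit_cast<float>(expected) + operand);
  } while (!word.compare_exchange_weak(expected, desired));
}
```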
@@ -760,11 +766,10 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
   // If "computation" is one of this kind, emits code to do that and returns
   // true; otherwise, returns false.
   LogicalResult rewriteAsDirectAtomicRMW(
-      AtomicRMWOp op, mlir::PatternRewriter& rewriter) const {
-    auto modifier_parameters = GetAtomicModifierParameters(op);
-    if (!modifier_parameters.has_value()) {
-      return failure();
-    }
+      AtomicRMWOp op,
+      std::optional<std::pair<mlir::Value, ml::AtomicBinOp>>
+          modifier_parameters,
+      mlir::PatternRewriter& rewriter) const {
     Value modifier_arg = modifier_parameters->first;
     Type element_type = modifier_arg.getType();
     ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;
@@ -803,7 +808,7 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
                    : emitNVidiaAtomicFAdd(
                          loc, modifier_arg, addr, sync_scope,
                          device_description_->cuda_compute_capability(),
-                         rewriter);
+                         rewriter, op);
       }
       case ml::AtomicBinOp::fmax: {
         return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
@@ -817,8 +822,8 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
 
   LogicalResult emitNVidiaAtomicFAdd(
       Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope,
-      const se::CudaComputeCapability& cuda_compute_capability,
-      OpBuilder& b) const {
+      const se::CudaComputeCapability& cuda_compute_capability, OpBuilder& b,
+      AtomicRMWOp& op) const {
     Type element_type = modifier_arg.getType();
     // "atom.add.f64 requires sm_60 or higher."
     // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
@@ -831,10 +836,34 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
     bool is_supported_f64_atomic =
         element_type.isF64() &&
         cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_);
+    auto vector_type = dyn_cast_or_null<mlir::VectorType>(element_type);
+    bool is_supported_vector_atomic =
+        vector_type && vector_type.getElementType().isF32() &&
+        (vector_type.getNumElements() == 2 ||
+         vector_type.getNumElements() == 4) &&
+        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER);
     if (!element_type.isF32() && !is_supported_f16_atomic &&
-        !is_supported_bf16_atomic && !is_supported_f64_atomic) {
+        !is_supported_bf16_atomic && !is_supported_f64_atomic &&
+        !is_supported_vector_atomic) {
       return failure();
     }
+
+    // TODO(389862360): Currently vectorized AtomicRMWOp lowers incorrectly to
+    // PTX due to a bug in NVPTX. We scalarize it for now.
+    if (is_supported_vector_atomic) {
+      mlir::ImplicitLocOpBuilder imp_b(loc, b);
+      auto base = GetLinearIndex(op.getIndices(), imp_b);
+      for (int i = 0; i < vector_type.getNumElements(); ++i) {
+        auto modifier_arg_i = imp_b.create<vector::ExtractOp>(modifier_arg, i);
+        auto offset = imp_b.create<ml::ConstantOp>(base.getType(), i);
+        auto addr_i = imp_b.create<ml::AddOp>(base, offset);
+        Value gep = CreateGep(op.getInput(), addr_i, imp_b);
+        imp_b.create<ml::AtomicRMWOp>(ml::AtomicBinOp::fadd, gep,
+                                      modifier_arg_i,
+                                      ml::AtomicOrdering::seq_cst, sync_scope);
+      }
+      return success();
+    }
     b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
                               ml::AtomicOrdering::seq_cst, sync_scope);
     return success();
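The net effect of the scalarization above: a `vector<2xf32>` or `vector<4xf32>` atomic add is not emitted as one vector atomic but as one scalar `fadd` atomic per lane at consecutive addresses. A host-side model (hypothetical helper; `fetch_add` on `std::atomic<float>` requires C++20):

```cpp
#include <atomic>

// Scalarized model of the vector-atomic workaround: N contiguous f32 lanes
// get N independent scalar atomic adds at base + 0 .. base + N - 1. In the
// pass, each iteration becomes its own ml::AtomicRMWOp fadd on a GEP
// computed from the linear index plus the lane offset.
void ScalarizedVectorAtomicAdd(std::atomic<float>* base, const float* operand,
                               int num_elements) {
  for (int i = 0; i < num_elements; ++i) {
    base[i].fetch_add(operand[i]);  // one scalar atomic per lane
  }
}
```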
@@ -1116,17 +1145,17 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
       return;
     }
 
-    getOperation()->walk([this](mlir::LLVM::LoadOp load) {
+    getOperation()->walk([this](ml::LoadOp load) {
       Value addr = load.getAddr();
-      while (auto gep = addr.getDefiningOp<mlir::LLVM::GEPOp>()) {
+      while (auto gep = addr.getDefiningOp<ml::GEPOp>()) {
         addr = gep.getBase();
       }
       while (auto cast = addr.getDefiningOp<UnrealizedConversionCastOp>()) {
         addr = cast.getOperand(0);
       }
-      if (addr.getDefiningOp<mlir::LLVM::AddrSpaceCastOp>() ||
-          addr.getDefiningOp<mlir::LLVM::AddressOfOp>() ||
-          addr.getDefiningOp<mlir::LLVM::AllocaOp>()) {
+      if (addr.getDefiningOp<ml::AddrSpaceCastOp>() ||
+          addr.getDefiningOp<ml::AddressOfOp>() ||
+          addr.getDefiningOp<ml::AllocaOp>()) {
         // Shared memory, global constant or temporary - no need to annotate
         // anything.
         return;