
Commit c533808

Moerafaat authored and Google-ML-Automation committed
Adding vectorization support for atomic_rmw.
Currently this only supports f32 vectors of size 2 or 4. There is a bug in LLVM when lowering to PTX that lowers the vectorized atomic RMW incorrectly; for now we scalarize, so the vectorization is effectively disabled. This should be followed up with a direct lowering to PTX as a workaround.

PiperOrigin-RevId: 715402916
1 parent c729ff7 commit c533808
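To make the scalarization fallback concrete, here is a minimal, illustrative CUDA sketch (not part of this commit; AtomicAddF4Scalarized and AccumulateKernel are hypothetical names): each lane of a float4 partial result is accumulated with its own f32 atomicAdd, which is semantically what the lowering emits today instead of a single vector-wide atomic.

```cuda
#include <cuda_runtime.h>

// Illustrative only: accumulate a float4 partial result into `dest` the way
// the pass currently scalarizes a vectorized atomic RMW -- one f32 atomicAdd
// per vector element instead of one vector-wide atomic instruction.
__device__ void AtomicAddF4Scalarized(float* dest, float4 v) {
  atomicAdd(dest + 0, v.x);  // element 0
  atomicAdd(dest + 1, v.y);  // element 1
  atomicAdd(dest + 2, v.z);  // element 2
  atomicAdd(dest + 3, v.w);  // element 3
}

__global__ void AccumulateKernel(float* dest, const float4* partials, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Every thread contributes its float4 partial sum to the same 4-float
    // destination slot; contention is resolved by the scalar atomics.
    AtomicAddF4Scalarized(dest, partials[i]);
  }
}
```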

File tree

5 files changed: +304 −75 lines changed


xla/backends/gpu/codegen/transforms/lower_tensors.cc

Lines changed: 91 additions & 62 deletions
@@ -73,6 +73,7 @@ namespace {
 #define GEN_PASS_DEF_LOWERTENSORSPASS
 #include "xla/backends/gpu/codegen/transforms/passes.h.inc"
 
+using llvm::dyn_cast_or_null;
 using mlir::failure;
 using mlir::Location;
 using mlir::LogicalResult;
@@ -93,6 +94,7 @@ using mlir::ValueRange;
 namespace arith = ::mlir::arith;
 namespace scf = ::mlir::scf;
 namespace ml = ::mlir::LLVM;
+namespace vector = ::mlir::vector;
 
 bool IsAMD(const se::DeviceDescription& device_description) {
   return std::holds_alternative<se::RocmComputeCapability>(
@@ -114,7 +116,7 @@ Value GetDestinationBuffer(Value dest) {
         dest.getDefiningOp<AllocateSharedOp>()) {
       break;
     } else if (auto transfer_write =
-                   dest.getDefiningOp<mlir::vector::TransferWriteOp>()) {
+                   dest.getDefiningOp<vector::TransferWriteOp>()) {
       dest = transfer_write.getSource();
     } else {
       dest.getDefiningOp()->emitOpError("unsupported dest type");
@@ -168,7 +170,7 @@ struct RewriteFunctionSignatures : OpRewritePattern<mlir::func::FuncOp> {
         auto cast = rewriter.create<UnrealizedConversionCastOp>(
             op.getLoc(), operand, op.getArgument(index));
         op.getArgument(index).replaceAllUsesExcept(cast.getResult(0), cast);
-        operand = mlir::LLVM::LLVMPointerType::get(op.getContext());
+        operand = ml::LLVMPointerType::get(op.getContext());
       }
     }
 
@@ -188,7 +190,7 @@ Value GetPtr(Value value) {
   }
   if (auto cast = value.getDefiningOp<UnrealizedConversionCastOp>()) {
     if (cast.getNumOperands() == 1 && cast.getNumResults() == 1 &&
-        mlir::isa<mlir::LLVM::LLVMPointerType>(cast.getOperand(0).getType())) {
+        mlir::isa<ml::LLVMPointerType>(cast.getOperand(0).getType())) {
       return cast.getOperand(0);
     }
   }
@@ -294,25 +296,25 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
   return {i8_index, is_low_nibble};
 }
 
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
-                            Value linear_index, mlir::ImplicitLocOpBuilder& b) {
+ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+                    Value linear_index, mlir::ImplicitLocOpBuilder& b) {
   Type element_type = tensor.getType().getElementType();
   if (element_type == b.getI4Type()) {
     element_type = b.getI8Type();
   }
-  auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
+  auto ptr = ml::LLVMPointerType::get(b.getContext());
   auto tensor_ptr =
       b.create<UnrealizedConversionCastOp>(ptr, tensor).getResult(0);
   mlir::LLVMTypeConverter converter(b.getContext());
   auto llvm_element_type = converter.convertType(element_type);
-  auto gep = b.create<mlir::LLVM::GEPOp>(ptr, llvm_element_type, tensor_ptr,
-                                         linear_index);
+  auto gep =
+      b.create<ml::GEPOp>(ptr, llvm_element_type, tensor_ptr, linear_index);
   gep.setInbounds(true);
   return gep;
 }
 
-mlir::LLVM::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
-                            ValueRange indices, mlir::ImplicitLocOpBuilder& b) {
+ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
+                    ValueRange indices, mlir::ImplicitLocOpBuilder& b) {
   return CreateGep(tensor, GetLinearIndex(indices, b), b);
 }
 
@@ -333,8 +335,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
 
     auto gep = CreateGep(op.getTensor(), linear_index, b);
     auto load =
-        rewriter
-            .create<mlir::LLVM::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
+        rewriter.create<ml::LoadOp>(gep.getLoc(), gep.getElemType(), gep)
             .getResult();
 
     if (is_low_nibble) {
@@ -359,19 +360,19 @@ Value PermutePairsInVector(Value vector, mlir::ImplicitLocOpBuilder& b) {
   int size = ty.getNumElements();
   Value result = vector;
   for (int i = 0; i < size; i += 2) {
-    auto v0 = b.create<mlir::vector::ExtractOp>(vector, i);
-    auto v1 = b.create<mlir::vector::ExtractOp>(vector, i + 1);
-    result = b.create<mlir::vector::InsertOp>(v1, result, i);
-    result = b.create<mlir::vector::InsertOp>(v0, result, i + 1);
+    auto v0 = b.create<vector::ExtractOp>(vector, i);
+    auto v1 = b.create<vector::ExtractOp>(vector, i + 1);
+    result = b.create<vector::InsertOp>(v1, result, i);
+    result = b.create<vector::InsertOp>(v0, result, i + 1);
   }
   return result;
 }
 
-struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
+struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(
-      mlir::vector::TransferReadOp op,
+      vector::TransferReadOp op,
       mlir::PatternRewriter& rewriter) const override {
     assert(IsSupportedTransfer(op));
 
@@ -394,8 +395,7 @@ struct RewriteTransferRead : OpRewritePattern<mlir::vector::TransferReadOp> {
 
     mlir::LLVMTypeConverter converter(b.getContext());
     auto llvm_vector_type = converter.convertType(vector_type);
-    auto loaded =
-        b.create<mlir::LLVM::LoadOp>(llvm_vector_type, gep).getResult();
+    auto loaded = b.create<ml::LoadOp>(llvm_vector_type, gep).getResult();
 
     if (source.getType().getElementType().isInteger(1)) {
       Value zero = b.create<mlir::arith::ConstantOp>(
@@ -484,7 +484,7 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
       scalar_value =
          b.create<UnrealizedConversionCastOp>(llvm_type, scalar_value)
              .getResult(0);
-      b.create<mlir::LLVM::StoreOp>(scalar_value, gep);
+      b.create<ml::StoreOp>(scalar_value, gep);
       op.replaceAllUsesWith(op.getDest());
     }
 
@@ -493,11 +493,11 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
   }
 };
 
-struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
+struct RewriteTransferWrite : OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(
-      mlir::vector::TransferWriteOp op,
+      vector::TransferWriteOp op,
       mlir::PatternRewriter& rewriter) const override {
     assert(IsSupportedTransfer(op));
     Value dest = GetDestinationBuffer(op.getSource());
@@ -526,7 +526,7 @@ struct RewriteTransferWrite : OpRewritePattern<mlir::vector::TransferWriteOp> {
     auto llvm_type = converter.convertType(vector_value.getType());
     vector_value = b.create<UnrealizedConversionCastOp>(llvm_type, vector_value)
                        .getResult(0);
-    b.create<mlir::LLVM::StoreOp>(vector_value, gep);
+    b.create<ml::StoreOp>(vector_value, gep);
 
     rewriter.replaceOp(op, mlir::ValueRange{op.getSource()});
     return success();
@@ -550,21 +550,19 @@ struct RewriteCall : OpRewritePattern<mlir::func::CallOp> {
            index,
            rewriter
                .create<UnrealizedConversionCastOp>(
-                   op.getLoc(),
-                   mlir::LLVM::LLVMPointerType::get(op.getContext()), arg)
+                   op.getLoc(), ml::LLVMPointerType::get(op.getContext()), arg)
                .getResult(0));
      }
    }
    return success();
  }
 };
 
-mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
-                                    const std::string& name_prefix,
-                                    mlir::ShapedType shaped_ty,
-                                    mlir::ModuleOp module, bool is_constant,
-                                    int addr_space,
-                                    mlir::ImplicitLocOpBuilder& b) {
+ml::GlobalOp CreateGlobalOp(mlir::Attribute value,
+                            const std::string& name_prefix,
+                            mlir::ShapedType shaped_ty, mlir::ModuleOp module,
+                            bool is_constant, int addr_space,
+                            mlir::ImplicitLocOpBuilder& b) {
   if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
     // The lowering to LLVM only works for 1d tensors or those with trailing
     // unit dimensions.
@@ -593,19 +591,17 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
                            packed_data);
     }
   }
-  auto array_ty =
-      mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements);
+  auto array_ty = ml::LLVMArrayType::get(llvm_element_type, num_elements);
   std::string name;
   int index = 0;
   do {
     name = absl::StrCat(name_prefix, index);
     ++index;
   } while (module.lookupSymbol(name));
   b.setInsertionPointToStart(module.getBody());
-  return b.create<mlir::LLVM::GlobalOp>(
-      array_ty, is_constant,
-      /*linkage=*/mlir::LLVM::Linkage::Private, name, value, /*alignment=*/0,
-      addr_space);
+  return b.create<ml::GlobalOp>(array_ty, is_constant,
+                                /*linkage=*/ml::Linkage::Private, name, value,
+                                /*alignment=*/0, addr_space);
 }
 
 struct RewriteAllocateShared : OpRewritePattern<AllocateSharedOp> {
@@ -623,13 +619,12 @@ struct RewriteAllocateShared : OpRewritePattern<AllocateSharedOp> {
                          /*is_constant=*/false, kGPUSharedMemoryAddrSpace, b);
 
     rewriter.setInsertionPoint(op);
-    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+    auto addr = rewriter.create<ml::AddressOfOp>(op.getLoc(), global);
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, op.getResult().getType(),
        rewriter
-            .create<mlir::LLVM::AddrSpaceCastOp>(
-                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
-                addr)
+            .create<ml::AddrSpaceCastOp>(
+                op.getLoc(), ml::LLVMPointerType::get(op.getContext()), addr)
            .getResult());
    return success();
   }
@@ -659,13 +654,12 @@ struct RewriteNonScalarConstants : OpRewritePattern<mlir::arith::ConstantOp> {
                          /*is_constant=*/true, kGPUGlobalMemoryAddrSpace, b);
 
     rewriter.setInsertionPoint(op);
-    auto addr = rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), global);
+    auto addr = rewriter.create<ml::AddressOfOp>(op.getLoc(), global);
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, op.getResult().getType(),
        rewriter
-            .create<mlir::LLVM::AddrSpaceCastOp>(
-                op.getLoc(), mlir::LLVM::LLVMPointerType::get(op.getContext()),
-                addr)
+            .create<ml::AddrSpaceCastOp>(
+                op.getLoc(), ml::LLVMPointerType::get(op.getContext()), addr)
            .getResult());
    return success();
   }
@@ -727,7 +721,7 @@ Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {
   // direct bitcast from a struct to an int is possible.
   Type llvm_input_ty = converter.convertType(value.getType());
   Type llvm_result_ty = converter.convertType(ty);
-  Type ptr_ty = mlir::LLVM::LLVMPointerType::get(b.getContext());
+  Type ptr_ty = ml::LLVMPointerType::get(b.getContext());
 
   Value llvm_value =
       b.create<UnrealizedConversionCastOp>(llvm_input_ty, value).getResult(0);
@@ -747,7 +741,19 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
 
   LogicalResult matchAndRewrite(
       AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
-    if (failed(rewriteAsDirectAtomicRMW(op, rewriter))) {
+    auto modifier_parameters = GetAtomicModifierParameters(op);
+    if (modifier_parameters.has_value()) {
+      if (mlir::isa<mlir::VectorType>(modifier_parameters->first.getType()) &&
+          (IsAMD(*device_description_) ||
+           !device_description_->cuda_compute_capability().IsAtLeastHopper())) {
+        return rewriter.notifyMatchFailure(
+            op,
+            "atomic vectorization currently only supported on Hopper or later");
+      }
+    }
+
+    if (!modifier_parameters.has_value() ||
+        failed(rewriteAsDirectAtomicRMW(op, modifier_parameters, rewriter))) {
       rewriteAsAtomicCAS(op, rewriter);
     }
     rewriter.replaceOp(op, op.getInput());
@@ -760,11 +766,10 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
   // If "computation" is one of this kind, emits code to do that and returns
   // true; otherwise, returns false.
   LogicalResult rewriteAsDirectAtomicRMW(
-      AtomicRMWOp op, mlir::PatternRewriter& rewriter) const {
-    auto modifier_parameters = GetAtomicModifierParameters(op);
-    if (!modifier_parameters.has_value()) {
-      return failure();
-    }
+      AtomicRMWOp op,
+      std::optional<std::pair<mlir::Value, ml::AtomicBinOp>>
+          modifier_parameters,
+      mlir::PatternRewriter& rewriter) const {
     Value modifier_arg = modifier_parameters->first;
     Type element_type = modifier_arg.getType();
     ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;
@@ -803,7 +808,7 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
                   : emitNVidiaAtomicFAdd(
                         loc, modifier_arg, addr, sync_scope,
                         device_description_->cuda_compute_capability(),
-                        rewriter);
+                        rewriter, op);
      }
      case ml::AtomicBinOp::fmax: {
        return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
@@ -817,8 +822,8 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
 
   LogicalResult emitNVidiaAtomicFAdd(
       Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope,
-      const se::CudaComputeCapability& cuda_compute_capability,
-      OpBuilder& b) const {
+      const se::CudaComputeCapability& cuda_compute_capability, OpBuilder& b,
+      AtomicRMWOp& op) const {
     Type element_type = modifier_arg.getType();
     // "atom.add.f64 requires sm_60 or higher."
     // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
@@ -831,10 +836,34 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
     bool is_supported_f64_atomic =
        element_type.isF64() &&
        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::PASCAL_);
+    auto vector_type = dyn_cast_or_null<mlir::VectorType>(element_type);
+    bool is_supported_vector_atomic =
+        vector_type && vector_type.getElementType().isF32() &&
+        (vector_type.getNumElements() == 2 ||
+         vector_type.getNumElements() == 4) &&
+        cuda_compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER);
     if (!element_type.isF32() && !is_supported_f16_atomic &&
-        !is_supported_bf16_atomic && !is_supported_f64_atomic) {
+        !is_supported_bf16_atomic && !is_supported_f64_atomic &&
+        !is_supported_vector_atomic) {
       return failure();
     }
+
+    // TODO(389862360): Currently vectorized AtomicRMWOp lowers incorrectly to
+    // PTX due to a bug in NVPTX. We scalarize it for now.
+    if (is_supported_vector_atomic) {
+      mlir::ImplicitLocOpBuilder imp_b(loc, b);
+      auto base = GetLinearIndex(op.getIndices(), imp_b);
+      for (int i = 0; i < vector_type.getNumElements(); ++i) {
+        auto modifier_arg_i = imp_b.create<vector::ExtractOp>(modifier_arg, i);
+        auto offset = imp_b.create<ml::ConstantOp>(base.getType(), i);
+        auto addr_i = imp_b.create<ml::AddOp>(base, offset);
+        Value gep = CreateGep(op.getInput(), addr_i, imp_b);
+        imp_b.create<ml::AtomicRMWOp>(ml::AtomicBinOp::fadd, gep,
+                                      modifier_arg_i,
+                                      ml::AtomicOrdering::seq_cst, sync_scope);
+      }
+      return success();
+    }
     b.create<ml::AtomicRMWOp>(loc, ml::AtomicBinOp::fadd, addr, modifier_arg,
                               ml::AtomicOrdering::seq_cst, sync_scope);
     return success();
@@ -1116,17 +1145,17 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
      return;
    }
 
-    getOperation()->walk([this](mlir::LLVM::LoadOp load) {
+    getOperation()->walk([this](ml::LoadOp load) {
      Value addr = load.getAddr();
-      while (auto gep = addr.getDefiningOp<mlir::LLVM::GEPOp>()) {
+      while (auto gep = addr.getDefiningOp<ml::GEPOp>()) {
        addr = gep.getBase();
      }
      while (auto cast = addr.getDefiningOp<UnrealizedConversionCastOp>()) {
        addr = cast.getOperand(0);
      }
-      if (addr.getDefiningOp<mlir::LLVM::AddrSpaceCastOp>() ||
-          addr.getDefiningOp<mlir::LLVM::AddressOfOp>() ||
-          addr.getDefiningOp<mlir::LLVM::AllocaOp>()) {
+      if (addr.getDefiningOp<ml::AddrSpaceCastOp>() ||
+          addr.getDefiningOp<ml::AddressOfOp>() ||
+          addr.getDefiningOp<ml::AllocaOp>()) {
        // Shared memory, global constant or temporary - no need to annotate
        // anything.
        return;

xla/backends/gpu/codegen/transforms/passes.td

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
     "mlir::tensor::TensorDialect",
     "xla::gpu::XlaGpuDialect",
     "xla::XlaDialect",
+    "mlir::vector::VectorDialect",
   ];
   let options = [
     Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"",

0 commit comments
