Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);

} // namespace detail

/// Decomposes a `src` value into a set of values of type `dstType` through a
/// series of bitcasts and vector ops. Src and dst types are expected to be int
/// or float types or vector types of them.
///
/// A single value is returned when `src` already has type `dstType`, has the
/// same bit width as `dstType`, or is narrower than `dstType` (it is then
/// zero-extended). Otherwise the bit width of `src` must be a multiple of the
/// bit width of `dstType`, and one value per `dstType`-sized piece is returned.
SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
Type dstType);

/// Composes a set of `src` values into a single value of type `dstType` through
/// a series of bitcasts and vector ops. This is the inverse of
/// `decomposeValue`: it combines multiple values into a single value.
/// `src` must not be empty.
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
Type dstType);
} // namespace LLVM

/// Base class for operation conversions targeting the LLVM IR dialect. It
Expand Down
33 changes: 33 additions & 0 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
}
};

/// Lowers `SwizzleBitModeOp` to `ROCDL::DsSwizzleOp`. The source value is
/// decomposed into i32 pieces, every piece is swizzled with the same packed
/// mask, and the swizzled pieces are composed back into the source type.
struct AMDGPUSwizzleBitModeLowering
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated to this PR, but I think it'd be more intuitive if the op name isdsSwizzleOp, and we can have the QDMode and the BitMode as a attribute of this op. Additionally we can also reference https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QDMode and BitMode variants have a different offsets format, my idea was to have them as separate ops in AMDGPU, so user won't need to bother with offset bitpacking even if they are both lowered to ds_swizzle eventually.

using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type i32 = rewriter.getI32Type();
Value src = adaptor.getSrc();
// Split the (already converted) source into i32 values, since every swizzle
// below is created on one i32 value at a time.
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, src, i32);
unsigned andMask = op.getAndMask();
unsigned orMask = op.getOrMask();
unsigned xorMask = op.getXorMask();

// bit 15 is 0 for the BitMode swizzle.
// The and/or/xor masks are packed at bit offsets 0, 5 and 10 respectively.
// See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we link to the ISA manual or llvm intrinsics here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
Value maskValue = createI32Constant(rewriter, loc, mask);
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res =
rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
swizzled.emplace_back(res);
}

Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
rewriter.replaceOp(op, result);
return success();
}
};

struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
Expand Down Expand Up @@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
92 changes: 92 additions & 0 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite(
rewriter.replaceOp(op, results);
return success();
}

/// Returns the total bit width of `type`: the scalar width for int or float
/// types, otherwise the number of elements times the (recursively computed)
/// element width of a vector type. Any other type fails the `cast`.
/// Scalable vectors are rejected with an assertion.
static unsigned getBitWidth(Type type) {
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();

auto vec = cast<VectorType>(type);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to assert this is not a scalable vector or bail out in some other way

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, updated verifier to reject scalable vectors, added assert here.

// getNumElements() does not describe the runtime size of a scalable vector,
// so the width computed below would be wrong for one.
assert(!vec.isScalable() && "scalable vectors are not supported");
return vec.getNumElements() * getBitWidth(vec.getElementType());
}

/// Creates an `LLVM::ConstantOp` of i32 type holding `value` at `loc` and
/// returns its result.
static Value createI32Constant(OpBuilder &builder, Location loc,
                               int32_t value) {
  return builder.create<LLVM::ConstantOp>(loc, builder.getI32Type(), value);
}

/// Splits `src` into values of `dstType`:
///  - same type: returns `src` unchanged;
///  - same bit width: returns a single bitcast of `src` to `dstType`;
///  - `dstType` wider than `src`: bitcasts `src` to an integer of its own bit
///    width if needed, then zero-extends it to an integer of the dst width;
///  - otherwise: bitcasts `src` to a vector of `dstType` and returns every
///    element, extracted in index order.
/// NOTE(review): in the widening case the result is an integer of the dst bit
/// width rather than `dstType` itself; confirm callers only pass an integer
/// `dstType` there.
SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
Value src, Type dstType) {
Type srcType = src.getType();
if (srcType == dstType)
return {src};

unsigned srcBitWidth = getBitWidth(srcType);
unsigned dstBitWidth = getBitWidth(dstType);
// Equal widths: a plain reinterpretation is enough.
if (srcBitWidth == dstBitWidth) {
Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
return {cast};
}

// Narrower source: go through same-width integers so that zext is legal.
if (dstBitWidth > srcBitWidth) {
auto smallerInt = builder.getIntegerType(srcBitWidth);
if (srcType != smallerInt)
src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);

auto largerInt = builder.getIntegerType(dstBitWidth);
Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
return {res};
}
// Wider source: it must split evenly into `dstType`-sized pieces.
assert(srcBitWidth % dstBitWidth == 0 &&
"src bit width must be a multiple of dst bit width");
int64_t numElements = srcBitWidth / dstBitWidth;
auto vecType = VectorType::get(numElements, dstType);

src = builder.create<LLVM::BitcastOp>(loc, vecType, src);

SmallVector<Value> res;
for (auto i : llvm::seq<int64_t>(0, numElements)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd use a plain for loop here...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like llvm:seq more than C loops. But switched to 1-arg version.

Value idx = createI32Constant(builder, loc, i);
Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
res.emplace_back(elem);
}

return res;
}

/// Builds a single value of `dstType` from the non-empty range `src`:
///  - one element: returned as is when its type already matches; when it is
///    wider than `dstType` it is bitcast to an integer of its own width (if
///    needed) and truncated to the dst width; a final bitcast converts the
///    result to `dstType` if the types still differ;
///  - several elements: they are inserted in order into a poison vector whose
///    element type is the type of the first element, and the vector is bitcast
///    to `dstType` if the types differ.
/// NOTE(review): a single element narrower than `dstType` reaches the final
/// bitcast with mismatched widths; confirm no caller relies on that case.
Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
Type dstType) {
assert(!src.empty() && "src range must not be empty");
if (src.size() == 1) {
Value res = src.front();
if (res.getType() == dstType)
return res;

unsigned srcBitWidth = getBitWidth(res.getType());
unsigned dstBitWidth = getBitWidth(dstType);
// Drop the high bits when the single element is wider than the destination;
// the truncation is done on same-width integers.
if (dstBitWidth < srcBitWidth) {
auto largerInt = builder.getIntegerType(srcBitWidth);
if (res.getType() != largerInt)
res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);

auto smallerInt = builder.getIntegerType(dstBitWidth);
res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
}

if (res.getType() != dstType)
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very cool, didn't know llvm::bitcast can transform from scalar to vector types with same bitwidth


return res;
}

int64_t numElements = src.size();
auto srcType = VectorType::get(numElements, src.front().getType());
Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
for (auto &&[i, elem] : llvm::enumerate(src)) {
Value idx = createI32Constant(builder, loc, i);
res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
}

if (res.getType() != dstType)
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);

return res;
}
75 changes: 75 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s

// An i32 source is swizzled directly, with no casts around the swizzle.
// The expected mask constant is 4161 == 1 | (2 << 5) | (4 << 10), i.e. the
// three immediate operands packed at bit offsets 0, 5 and 10.
// CHECK-LABEL: func @test_swizzle_i32
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
// CHECK: return %[[RES]] : i32
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
return %0 : i32
}

// A 32-bit float source is bitcast to i32, swizzled, and bitcast back to f32.
// CHECK-LABEL: func @test_swizzle_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32)
func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
return %0 : f32
}

// A source narrower than 32 bits is bitcast to i16, zero-extended to i32,
// swizzled, then truncated to i16 and bitcast back to f16.
// CHECK-LABEL: func @test_swizzle_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16)
func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
return %0 : f16
}

// A vector of i32 is handled element by element: each element is extracted,
// swizzled on its own, and inserted into a fresh poison vector of the same
// type. No bitcast is expected since the element type is already i32.
// CHECK-LABEL: func @test_swizzle_2xi32
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[V3]] : vector<2xi32>
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
return %0 : vector<2xi32>
}

// A 64-bit vector of f16 is first bitcast to vector<2xi32>; the two i32
// elements are swizzled separately, reassembled into a vector<2xi32>, and the
// result is bitcast back to vector<4xf16>.
// CHECK-LABEL: func @test_swizzle_4xf16
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
// CHECK: return %[[CAST2]] : vector<4xf16>
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
return %0 : vector<4xf16>
}
Loading