129 changes: 129 additions & 0 deletions test/Conversion/intel/sub-group-transpose.mlir
@@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
tt.return %0 : tensor<64x16x4xf32, #blocked1>
}
}

// -----

// Test transposition with 32 elements per work-item.

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
// CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
// CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
// CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
// CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
// CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
// CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
// CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
// CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
tt.return %0 : tensor<32x16xf32, #blocked1>
}
}

// -----

// Test transposition with 32 elements per work-item with a different layout.

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
// CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
// CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
// CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
// CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
// CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
// CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
// CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
// CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
tt.return %0 : tensor<16x32xf32, #blocked1>
}
}

// -----

// Test transposition with 32 elements per work-item and two warps in each dimension.

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
// CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
// CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
// CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
// CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
// CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
// CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
// CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
// CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
// CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
// CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
// CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
// CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
%0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
tt.return %0 : tensor<32x64xf32, #blocked1>
}
}
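
The three test cases above all exercise 32 elements per work-item even though their shapes and warp counts differ. A minimal standalone sketch of that arithmetic (assuming the blocked layouts tile their tensors exactly, with no broadcasting):

// Editorial sketch, not part of the diff: why each test comment above says
// "32 elements per work-item".
#include <cstdint>
#include <iostream>

int main() {
  struct Case {
    const char *name;
    int64_t rows, cols;         // tensor shape
    int64_t lanesRow, lanesCol; // threadsPerWarp
    int64_t warpsRow, warpsCol; // warpsPerCTA
  };
  const Case cases[] = {
      {"32x16, 1 warp", 32, 16, 1, 16, 1, 1},
      {"16x32, 1 warp", 16, 32, 1, 16, 1, 1},
      {"32x64, 2x2 warps", 32, 64, 1, 16, 2, 2},
  };
  for (const Case &c : cases) {
    int64_t workItems = c.lanesRow * c.lanesCol * c.warpsRow * c.warpsCol;
    std::cout << c.name << ": " << (c.rows * c.cols) / workItems
              << " elements per work-item\n";
  }
  return 0;
}

All three cases come out to 32, which is why each test needs the multi-load path introduced by the lowering change below (more than one load per work-item once the element count exceeds the sub-group size).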
@@ -462,17 +462,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// Expected conversion is:
// - register=1 -> (0, 1)
// ...
// - register=i -> (0, 2**(i-1))
// - register=2**i -> (0, 2**i)
// ...
// - register=N -> (0, 2**(N-1))
// - register=M -> (0, 2**M)
// ...
// - register=2**k -> (2**k, 0)
// ...
// - register=N -> (2**N, 0)
// - lane=1 -> (0, 1)
// ...
// - lane=j -> (2**(j-1), 0)
// - lane=2**j -> (2**j, 0)
// ...
// lane=M -> (2**(M-1), 0)
// where out dims are: [register (size 2**(N-1)), lane (size 2**(M-1))]
// lane=2**M -> (2**M, 0)
// where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
//
// With N = M.
// With N >= M.
const auto buildBasis = [&](int32_t size, std::size_t index) {
std::vector<std::vector<int32_t>> basis;
std::vector<int32_t> curr(2);
@@ -482,13 +486,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
return basis;
};

constexpr std::size_t laneIndex = 0;
constexpr std::size_t registerIndex = 1;
int32_t size = conversion->getInDimSize(kLane);
int32_t laneSize = conversion->getInDimSize(kLane);
std::vector<std::vector<int32_t>> registerBases =
buildBasis(laneSize, registerIndex);
{
// Populate register bases for N > M.
std::vector<int32_t> base(2);
for (int32_t i = laneSize,
registerSize = conversion->getInDimSize(kRegister);
i < registerSize; i *= 2) {
base[laneIndex] = i;
registerBases.push_back(base);
}
}
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
bases{{{kRegister, buildBasis(size, registerIndex)},
{kLane, buildBasis(size, laneIndex)}}};
bases{{{kRegister, std::move(registerBases)},
{kLane, buildBasis(laneSize, laneIndex)}}};
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
return conversion == LinearLayout(bases, outDimNames);
}
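
For the 32-element tests above, the lane dimension has size 16 and the register dimension has size 32, so the new N > M branch appends one extra register basis. A small standalone sketch of the bases this check then expects (the loop body of buildBasis is elided by the hunk above, so its reconstruction here is an assumption based on the surrounding comment):

// Editorial sketch, not part of the diff: the basis vectors expected for
// laneSize = 16 and registerSize = 32.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  constexpr std::size_t laneIndex = 0;
  constexpr std::size_t registerIndex = 1;
  constexpr int32_t laneSize = 16;
  constexpr int32_t registerSize = 32;

  // Assumed equivalent of buildBasis: one basis vector per power of two
  // strictly below `size`, with the power written at position `index`.
  auto buildBasis = [](int32_t size, std::size_t index) {
    std::vector<std::vector<int32_t>> basis;
    for (int32_t i = 1; i < size; i *= 2) {
      std::vector<int32_t> curr(2);
      curr[index] = i;
      basis.push_back(curr);
    }
    return basis;
  };

  // Register bases: (0,1) (0,2) (0,4) (0,8), then (16,0) because the
  // register dimension is larger than the lane dimension.
  std::vector<std::vector<int32_t>> registerBases =
      buildBasis(laneSize, registerIndex);
  for (int32_t i = laneSize; i < registerSize; i *= 2)
    registerBases.push_back({i, 0});

  // Lane bases: (1,0) (2,0) (4,0) (8,0).
  std::vector<std::vector<int32_t>> laneBases =
      buildBasis(laneSize, laneIndex);

  auto dump = [](const char *name,
                 const std::vector<std::vector<int32_t>> &bases) {
    std::cout << name << ":";
    for (const std::vector<int32_t> &b : bases)
      std::cout << " (" << b[0] << "," << b[1] << ")";
    std::cout << "\n";
  };
  dump("register", registerBases);
  dump("lane", laneBases);
  return 0;
}

Compiled and run, this prints register: (0,1) (0,2) (0,4) (0,8) (16,0) and lane: (1,0) (2,0) (4,0) (8,0), the shape of conversion the updated comment describes for N >= M.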
@@ -739,11 +754,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
OpAdaptor adaptor) const {
auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
ArrayRef<Type> body = srcType.getBody();
// TODO: Support more configurations.
auto mod = op->getParentOfType<ModuleOp>();
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
if (body.size() != threadsPerWarp)
return false;
return TypeSwitch<Type, bool>(body.front())
.Case([this](FloatType floatTy) {
// Support via bitcasting to integer type.
@@ -888,12 +898,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}

SmallVector<Value>
unwrapFromVector(Location loc, Value vec,
ConversionPatternRewriter &rewriter) const {
unwrapFromVectors(Location loc, ArrayRef<Value> vecs,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> res;
for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
i < n; ++i)
res.push_back(extract_element(vec, i32_val(i)));
for (Value vec : vecs) {
for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
i < n; ++i)
res.push_back(extract_element(vec, i32_val(i)));
}
return res;
}

@@ -908,6 +920,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
Type ptrType = smemBase.getType();

int numElements = inVals.size();
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
int offset = threadsPerWarp;
Type offsetType = getTypeConverter()->getIndexType();
@@ -922,7 +935,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
Value wiStride =
rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
Value sgStride = rewriter.create<LLVM::ConstantOp>(
loc, offsetType, threadsPerWarp * threadsPerWarp);
loc, offsetType, threadsPerWarp * numElements);
Value subGroupOffset = mul(sgStride, subGroupId);
Type elementType = opType.getElementType();
Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
@@ -939,13 +952,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}

// Load from matrix, non-transposed.
// As per SIMD block semantics, we have stored the elements in a matrix of
// `Nxsub_group_size` size, so we need to load back in blocks of
// `sub_group_size` (`N/sub_group_size` loads).
Value workItemOffset = mul(wiStride, subGroupLocalId);
Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
ValueRange{workItemOffset}, /*inbounds=*/true);
Value transposedVec =
load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr);

return unwrapFromVector(loc, transposedVec, rewriter);
SmallVector<Value> transposedVecs;
Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp);
for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) {
transposedVecs.push_back(load(loadTy, workItemBasePtr));
workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
}
return unwrapFromVectors(loc, transposedVecs, rewriter);
}
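
A host-side model of the round trip described in the comment above may make the addressing easier to follow. For the 32x16 test (sub_group_size 16, 32 elements per work-item) it reproduces the strides that show up in the FileCheck patterns: a 512-element region per sub-group, block-write steps of 128 i32, a per-lane read offset of 16 i32 times the local id, and two vector<16xi32> loads 256 i32 apart. The block-write placement modelled here (element k of lane l landing at offset k * sub_group_size + l) is an assumption about intel_sub_group_block_write_ui8 rather than something stated in the diff:

// Editorial sketch, not part of the diff: flat-array model of the SLM
// store/load transposition for T = 16 lanes and N = 32 elements per lane.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int T = 16;             // sub-group size / threads per warp
  constexpr int N = 32;             // elements per work-item
  std::vector<uint32_t> slm(T * N); // per-sub-group scratch: 512 i32

  // Store phase: four block writes of vector<8xi32>, each stepping by
  // 16 vector<8xi32> (= 128 i32), as in the CHECK lines of the tests.
  for (int lane = 0; lane < T; ++lane)
    for (int c = 0; c < N / 8; ++c)
      for (int k = 0; k < 8; ++k) {
        int element = c * 8 + k;             // register index within the lane
        uint32_t value = lane * N + element; // stand-in for the lane's data
        slm[c * (T * 8) + k * T + lane] = value;
      }

  // Load phase: each lane starts at lane * 16 i32 and performs N / T = 2
  // plain loads of vector<16xi32>, stepping by 16 vector<16xi32> (= 256 i32)
  // between loads.
  for (int lane = 0; lane < T; ++lane)
    for (int b = 0; b < N / T; ++b)
      for (int j = 0; j < T; ++j) {
        uint32_t got = slm[lane * T + b * (T * T) + j];
        // Lane `lane` now holds what lane `j` originally kept at register
        // b * T + lane, i.e. the data has been transposed across the
        // sub-group.
        assert(got == uint32_t(j * N + (b * T + lane)));
      }
  return 0;
}

The asserts pass because the store phase lays the data out as a row-major 32x16 matrix in SLM (register r of lane l at offset r * 16 + l), so the contiguous per-lane loads pick up whole rows rather than columns.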

LogicalResult