
Commit 0514102

[TritonIntelGPUToLLVM] Extend sub-group transposition support (#2521)

Extend sub-group transposition support, allowing `N*sub_group_size` elements per thread. As per block load semantics (a matrix of `sub_group_size` columns), we need `N` vector loads to load the transposed matrix from local memory.

Signed-off-by: victor-eds <[email protected]>

1 parent 6647f59 commit 0514102
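
A minimal scalar model of the local-memory round trip the commit message describes, assuming sub-group size 16 and 32 elements per work-item (N = 2) as in the new tests; the names here are illustrative, not taken from the patch:

#include <cassert>
#include <vector>

int main() {
  constexpr int subGroupSize = 16;   // triton_gpu.threads-per-warp in the tests
  constexpr int elemsPerThread = 32; // N * subGroupSize with N = 2
  std::vector<int> slm(elemsPerThread * subGroupSize);
  // Block-write semantics: register r of lane l lands at row r, column l of an
  // elemsPerThread x subGroupSize matrix in local memory.
  for (int lane = 0; lane < subGroupSize; ++lane)
    for (int reg = 0; reg < elemsPerThread; ++reg)
      slm[reg * subGroupSize + lane] = lane * elemsPerThread + reg;
  // Read-back: lane l loads N vectors of subGroupSize consecutive elements
  // (rows l, l + subGroupSize, ...), which is why N vector loads are needed.
  for (int lane = 0; lane < subGroupSize; ++lane)
    for (int v = 0; v < elemsPerThread / subGroupSize; ++v)
      for (int i = 0; i < subGroupSize; ++i) {
        int got = slm[(v * subGroupSize + lane) * subGroupSize + i];
        // Each lane ends up with the element lane i held in this register slot,
        // i.e. the transposed matrix.
        assert(got == i * elemsPerThread + (v * subGroupSize + lane));
      }
}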

File tree

test/Conversion/intel/sub-group-transpose.mlir
third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

2 files changed: +174 -25 lines changed

test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 129 additions & 0 deletions
@@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
     tt.return %0 : tensor<64x16x4xf32, #blocked1>
   }
 }
+
+// -----
+
+// Test transposition with 32 elements per work-item.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
+    tt.return %0 : tensor<32x16xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test transposition with 32 elements per work-item with a different layout.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
+    tt.return %0 : tensor<16x32xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test transposition with 32 elements per work-item and two warps in each dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
+    tt.return %0 : tensor<32x64xf32, #blocked1>
+  }
+}
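
The constants in the CHECK lines above all follow from the two sizes involved; a compile-time sketch (editorial, with the sizes assumed from the test configurations):

// Assumed from the tests: f32 values bitcast to i32 before the round trip.
constexpr int subGroupSize = 16;   // triton_gpu.threads-per-warp
constexpr int elemsPerThread = 32; // per work-item elements in all three tests
// Stride between sub-group slots in local memory: the `512 : i64` constant.
static_assert(subGroupSize * elemsPerThread == 512);
// Store side: four block writes of vector<8xi32>, each covering 8 x 16 values.
static_assert(4 * 8 * subGroupSize == subGroupSize * elemsPerThread);
// Load side: two vector<16xi32> loads per work-item.
static_assert(elemsPerThread / subGroupSize == 2);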

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 45 additions & 25 deletions
@@ -462,17 +462,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Expected conversion is:
     // - register=1 -> (0, 1)
     // ...
-    // - register=i -> (0, 2**(i-1))
+    // - register=2**i -> (0, 2**i)
     // ...
-    // - register=N -> (0, 2**(N-1))
+    // - register=M -> (0, 2**M)
+    // ...
+    // - register=2**k -> (2**k, 0)
+    // ...
+    // - register=N -> (2**N, 0)
     // - lane=1 -> (0, 1)
     // ...
-    // - lane=j -> (2**(j-1), 0)
+    // - lane=2**j -> (2**j, 0)
     // ...
-    // lane=M -> (2**(M-1), 0)
-    // where out dims are: [register (size 2**(N-1)), lane (size 2**(M-1))]
+    // lane=2**M -> (2**M, 0)
+    // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
     //
-    // With N = M.
+    // With N >= M.
     const auto buildBasis = [&](int32_t size, std::size_t index) {
       std::vector<std::vector<int32_t>> basis;
       std::vector<int32_t> curr(2);
@@ -482,13 +486,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return basis;
     };
-
     constexpr std::size_t laneIndex = 0;
     constexpr std::size_t registerIndex = 1;
-    int32_t size = conversion->getInDimSize(kLane);
+    int32_t laneSize = conversion->getInDimSize(kLane);
+    std::vector<std::vector<int32_t>> registerBases =
+        buildBasis(laneSize, registerIndex);
+    {
+      // Populate register bases for N > M.
+      std::vector<int32_t> base(2);
+      for (int32_t i = laneSize,
+                   registerSize = conversion->getInDimSize(kRegister);
+           i < registerSize; i *= 2) {
+        base[laneIndex] = i;
+        registerBases.push_back(base);
+      }
+    }
     std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, buildBasis(size, registerIndex)},
-               {kLane, buildBasis(size, laneIndex)}}};
+        bases{{{kRegister, std::move(registerBases)},
+               {kLane, buildBasis(laneSize, laneIndex)}}};
     std::array<StringAttr, 2> outDimNames{kRegister, kLane};
     return conversion == LinearLayout(bases, outDimNames);
   }
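
A worked example of the bases this equality check now expects (editorial sketch, not in the patch), for laneSize = 16 and registerSize = 32, i.e. M = 4 and N = 5 in the comment's notation; out-dim tuples are written as (register, lane):

#include <array>
#include <cstdio>
#include <vector>

int main() {
  constexpr int laneSize = 16, registerSize = 32;
  std::vector<std::array<int, 2>> registerBases, laneBases;
  // buildBasis(laneSize, registerIndex): register=2**i -> (0, 2**i).
  for (int i = 1; i < laneSize; i *= 2)
    registerBases.push_back({0, i});
  // The new loop for registerSize > laneSize: register=2**k -> (2**k, 0).
  for (int i = laneSize; i < registerSize; i *= 2)
    registerBases.push_back({i, 0});
  // buildBasis(laneSize, laneIndex): lane=2**j -> (2**j, 0).
  for (int j = 1; j < laneSize; j *= 2)
    laneBases.push_back({j, 0});
  // Prints (0,1) (0,2) (0,4) (0,8) (16,0) for registers, (1,0) (2,0) (4,0)
  // (8,0) for lanes.
  for (auto [r, l] : registerBases)
    std::printf("register base -> (%d, %d)\n", r, l);
  for (auto [r, l] : laneBases)
    std::printf("lane base -> (%d, %d)\n", r, l);
}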
@@ -739,11 +754,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                      OpAdaptor adaptor) const {
     auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
     ArrayRef<Type> body = srcType.getBody();
-    // TODO: Support more configurations.
-    auto mod = op->getParentOfType<ModuleOp>();
-    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-    if (body.size() != threadsPerWarp)
-      return false;
     return TypeSwitch<Type, bool>(body.front())
         .Case([this](FloatType floatTy) {
           // Support via bitcasting to integer type.
@@ -888,12 +898,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   }

   SmallVector<Value>
-  unwrapFromVector(Location loc, Value vec,
-                   ConversionPatternRewriter &rewriter) const {
+  unwrapFromVectors(Location loc, ArrayRef<Value> vecs,
+                    ConversionPatternRewriter &rewriter) const {
     SmallVector<Value> res;
-    for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
-         i < n; ++i)
-      res.push_back(extract_element(vec, i32_val(i)));
+    for (Value vec : vecs) {
+      for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
+           i < n; ++i)
+        res.push_back(extract_element(vec, i32_val(i)));
+    }
     return res;
   }

@@ -908,6 +920,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
     Type ptrType = smemBase.getType();

+    int numElements = inVals.size();
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
     int offset = threadsPerWarp;
     Type offsetType = getTypeConverter()->getIndexType();
@@ -922,7 +935,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     Value wiStride =
         rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
     Value sgStride = rewriter.create<LLVM::ConstantOp>(
-        loc, offsetType, threadsPerWarp * threadsPerWarp);
+        loc, offsetType, threadsPerWarp * numElements);
     Value subGroupOffset = mul(sgStride, subGroupId);
     Type elementType = opType.getElementType();
     Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
@@ -939,13 +952,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     }

     // Load from matrix, non-trasposed.
+    // As per SIMD block semantics, we have stored the elements in a matrix of
+    // `Nxsub_group_size` size, so we need to load back in blocks of
+    // `sub_group_size` (`N/sub_group_size` loads).
     Value workItemOffset = mul(wiStride, subGroupLocalId);
     Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
                                 ValueRange{workItemOffset}, /*inbounds=*/true);
-    Value transposedVec =
-        load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr);
-
-    return unwrapFromVector(loc, transposedVec, rewriter);
+    SmallVector<Value> transposedVecs;
+    Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp);
+    for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) {
+      transposedVecs.push_back(load(loadTy, workItemBasePtr));
+      workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
+                            ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
+    }
+    return unwrapFromVectors(loc, transposedVecs, rewriter);
   }

   LogicalResult
