Commit 602175b

[TritonIntelGPUToLLVM] Extend sub-group transposition support
Extend sub-group transposition support to allow `N*sub_group_size` elements per thread. As per block load semantics (a matrix with `sub_group_size` columns), we need `N` vector loads to read the transposed matrix back from local memory.

Signed-off-by: victor-eds <[email protected]>
1 parent 57b4375 commit 602175b
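
The lowering round-trips through local memory. Below is a minimal host-side sketch of the data movement, assuming `sub_group_size` work-items per sub-group (illustrative only; the actual pass emits LLVM-dialect ops and the SPIR-V block-write built-ins shown in the tests below, and `transposeViaScratch` is a hypothetical name):

```cpp
#include <cstddef>
#include <vector>

// in[lane][e] holds element e of work-item `lane`; returns the values each
// work-item owns after the transpose.
std::vector<std::vector<float>>
transposeViaScratch(const std::vector<std::vector<float>> &in,
                    std::size_t subGroupSize) {
  std::size_t numElements = in.front().size();
  // Block-write semantics: row e of the scratch matrix gathers element e of
  // every lane, giving a `numElements x subGroupSize` matrix per sub-group.
  std::vector<float> scratch(numElements * subGroupSize);
  for (std::size_t lane = 0; lane < subGroupSize; ++lane)
    for (std::size_t e = 0; e < numElements; ++e)
      scratch[e * subGroupSize + lane] = in[lane][e];
  // Transposed read-back: work-item `lane` starts at row `lane` and performs
  // numElements / subGroupSize contiguous loads of subGroupSize elements,
  // with consecutive loads subGroupSize rows apart.
  std::vector<std::vector<float>> out(subGroupSize);
  for (std::size_t lane = 0; lane < subGroupSize; ++lane)
    for (std::size_t v = 0; v < numElements / subGroupSize; ++v)
      for (std::size_t i = 0; i < subGroupSize; ++i)
        out[lane].push_back(
            scratch[(lane + v * subGroupSize) * subGroupSize + i]);
  return out;
}
```

With `sub_group_size` = 16 and 32 elements per work-item (the new tests), each sub-group spills a 32x16 matrix and every work-item reads it back with 32 / 16 = 2 vector loads.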

File tree

2 files changed: +164, -17 lines


test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 129 additions & 0 deletions
@@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
     tt.return %0 : tensor<64x16x4xf32, #blocked1>
   }
 }
+
+// -----
+
+// Test transposition with 32 elements per work-item.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
+    tt.return %0 : tensor<32x16xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test transposition with 32 elements per work-item with a different layout.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
+    tt.return %0 : tensor<16x32xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test transposition with 32 elements per work-item and two warps in each dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
+    tt.return %0 : tensor<32x64xf32, #blocked1>
+  }
+}
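
The constants in these tests follow directly from the scheme: each sub-group spills 32 * 16 = 512 elements, hence the `512 : i64` sub-group stride; the 32 values per work-item go out as four 8-wide block writes (4 * 8 = 32) and come back as 32 / 16 = 2 loads of `vector<16xi32>` (the CHECK lines pin down the first load and the pointer bump to the second).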

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 35 additions & 17 deletions
@@ -462,13 +462,26 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     //   register=2 -> (0, 2)
     //   register=4 -> (0, 4)
     //   register=8 -> (0, 8)
+    //   register=N -> (N, 0)
+    //   ...
     // - lane=1 -> (1, 0)
     //   lane=2 -> (2, 0)
     //   lane=4 -> (4, 0)
     //   lane=8 -> (8, 0)
-    // where out dims are: [register (size 16), lane (size 16)]
+    // where out dims are: [register (size 2*N), lane (size 16)]
+    std::vector<std::vector<int32_t>> registerBases{
+        {0, 1}, {0, 2}, {0, 4}, {0, 8}};
+    {
+      // Populate register bases for N > 8.
+      std::vector<int32_t> base(2);
+      for (int32_t i = 16, n = conversion->getInDimSize(kRegister); i < n;
+           i *= 2) {
+        base.front() = i;
+        registerBases.push_back(base);
+      }
+    }
     std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, {{0, 1}, {0, 2}, {0, 4}, {0, 8}}},
+        bases{{{kRegister, std::move(registerBases)},
               {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}};
     std::array<StringAttr, 2> outDimNames{kRegister, kLane};
     return conversion == LinearLayout(bases, outDimNames);
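
To make the new register bases concrete, here is a standalone sketch of the construction above (illustrative only; `makeRegisterBases` is a hypothetical name, and `numRegisters` stands in for `conversion->getInDimSize(kRegister)`):

```cpp
#include <cstdint>
#include <vector>

// The four fixed bases cover the first 16 registers; each power of two from
// 16 up to numRegisters adds an (i, 0) basis, mirroring the loop in the patch.
std::vector<std::vector<int32_t>> makeRegisterBases(int32_t numRegisters) {
  std::vector<std::vector<int32_t>> bases{{0, 1}, {0, 2}, {0, 4}, {0, 8}};
  for (int32_t i = 16; i < numRegisters; i *= 2)
    bases.push_back({i, 0});
  return bases;
}

// makeRegisterBases(16) -> {0,1} {0,2} {0,4} {0,8}         (previous behavior)
// makeRegisterBases(32) -> {0,1} {0,2} {0,4} {0,8} {16,0}  (new N > 8 case)
```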
@@ -572,11 +585,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                   OpAdaptor adaptor) const {
     auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
     ArrayRef<Type> body = srcType.getBody();
-    // TODO: Support more configurations.
-    auto mod = op->getParentOfType<ModuleOp>();
-    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-    if (body.size() != threadsPerWarp)
-      return false;
     return TypeSwitch<Type, bool>(body.front())
         .Case([this](FloatType floatTy) {
           // Support via bitcasting to integer type.
@@ -714,12 +722,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   }

   SmallVector<Value>
-  unwrapFromVector(Location loc, Value vec,
-                   ConversionPatternRewriter &rewriter) const {
+  unwrapFromVectors(Location loc, ArrayRef<Value> vecs,
+                    ConversionPatternRewriter &rewriter) const {
     SmallVector<Value> res;
-    for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
-         i < n; ++i)
-      res.push_back(extract_element(vec, i32_val(i)));
+    for (Value vec : vecs) {
+      for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
+           i < n; ++i)
+        res.push_back(extract_element(vec, i32_val(i)));
+    }
     return res;
   }
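
The flattening order matters to callers: all lanes of the first vector precede those of the second. A host-side analogue with scalars standing in for LLVM vector lanes (illustrative only, `unwrapAll` is a hypothetical name):

```cpp
#include <vector>

// Mirrors unwrapFromVectors: concatenate the lanes of each vector in order,
// so two 16-wide vectors flatten to 32 scalars, first vector's lanes first.
std::vector<float> unwrapAll(const std::vector<std::vector<float>> &vecs) {
  std::vector<float> res;
  for (const auto &vec : vecs)
    for (float lane : vec)
      res.push_back(lane);
  return res;
}
```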

@@ -734,6 +744,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
     Type ptrType = smemBase.getType();

+    int numElements = inVals.size();
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
     int offset = threadsPerWarp;
     Type offsetType = getTypeConverter()->getIndexType();
@@ -748,7 +759,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     Value wiStride =
        rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
     Value sgStride = rewriter.create<LLVM::ConstantOp>(
-        loc, offsetType, threadsPerWarp * threadsPerWarp);
+        loc, offsetType, threadsPerWarp * numElements);
     Value subGroupOffset = mul(sgStride, subGroupId);
     Type elementType = opType.getElementType();
     Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
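
This stride change is what makes room for more than `sub_group_size` elements per work-item: each sub-group now reserves `threadsPerWarp * numElements` elements of local memory rather than `threadsPerWarp * threadsPerWarp`. For the new 32-element tests that is 16 * 32 = 512, matching the `512 : i64` constant they check (the old square stride would only give 16 * 16 = 256).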
@@ -765,13 +776,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     }

     // Load from matrix, transposed.
+    // As per SIMD block semantics, we have stored the elements in a matrix of
+    // `Nxsub_group_size` size, so we need to load back in blocks of
+    // `sub_group_size` (`N/sub_group_size` loads).
     Value workItemOffset = mul(wiStride, subGroupLocalId);
     Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
                                 ValueRange{workItemOffset}, /*inbounds=*/true);
-    Value transposedVec =
-        load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr);
-
-    return unwrapFromVector(loc, transposedVec, rewriter);
+    SmallVector<Value> transposedVecs;
+    Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp);
+    for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) {
+      transposedVecs.push_back(load(loadTy, workItemBasePtr));
+      workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
+                            ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
+    }
+    return unwrapFromVectors(loc, transposedVecs, rewriter);
   }

   LogicalResult