diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir
index def61f6e73..8b2c5bd6aa 100644
--- a/test/Conversion/intel/sub-group-transpose.mlir
+++ b/test/Conversion/intel/sub-group-transpose.mlir
@@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
     tt.return %0 : tensor<64x16x4xf32, #blocked1>
   }
 }
+
+// -----
+
+// Test transposition with 32 elements per work-item.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
+    tt.return %0 : tensor<32x16xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test transposition with 32 elements per work-item and a different tensor shape.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
+    tt.return %0 : tensor<16x32xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test transposition with 32 elements per work-item and two warps in each dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
+  tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
+    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
+    tt.return %0 : tensor<32x64xf32, #blocked1>
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index a03989d765..aaec293bac 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -462,17 +462,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Expected conversion is:
     // - register=1 -> (0, 1)
     //   ...
-    // - register=i -> (0, 2**(i-1))
+    // - register=2**i -> (0, 2**i)
     //   ...
-    // - register=N -> (0, 2**(N-1))
+    // - register=2**M -> (0, 2**M)
+    //   ...
+    // - register=2**k -> (2**k, 0)
+    //   ...
+    // - register=2**N -> (2**N, 0)
     // - lane=1 -> (0, 1)
     //   ...
-    // - lane=j -> (2**(j-1), 0)
+    // - lane=2**j -> (2**j, 0)
     //   ...
-    // lane=M -> (2**(M-1), 0)
-    // where out dims are: [register (size 2**(N-1)), lane (size 2**(M-1))]
+    // lane=2**M -> (2**M, 0)
+    // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
     //
-    // With N = M.
+    // With N >= M.
     const auto buildBasis = [&](int32_t size, std::size_t index) {
       std::vector<std::vector<int32_t>> basis;
       std::vector<int32_t> curr(2);
@@ -482,13 +486,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return basis;
     };
-
     constexpr std::size_t laneIndex = 0;
     constexpr std::size_t registerIndex = 1;
-    int32_t size = conversion->getInDimSize(kLane);
+    int32_t laneSize = conversion->getInDimSize(kLane);
+    std::vector<std::vector<int32_t>> registerBases =
+        buildBasis(laneSize, registerIndex);
+    {
+      // Populate register bases for N > M.
+      std::vector<int32_t> base(2);
+      for (int32_t i = laneSize,
+                   registerSize = conversion->getInDimSize(kRegister);
+           i < registerSize; i *= 2) {
+        base[laneIndex] = i;
+        registerBases.push_back(base);
+      }
+    }
     std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
-        bases{{{kRegister, buildBasis(size, registerIndex)},
-               {kLane, buildBasis(size, laneIndex)}}};
+        bases{{{kRegister, std::move(registerBases)},
+               {kLane, buildBasis(laneSize, laneIndex)}}};
     std::array<StringAttr, 2> outDimNames{kRegister, kLane};
     return conversion == LinearLayout(bases, outDimNames);
   }
@@ -739,11 +754,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                                  OpAdaptor adaptor) const {
     auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
     ArrayRef<Type> body = srcType.getBody();
-    // TODO: Support more configurations.
-    auto mod = op->getParentOfType<ModuleOp>();
-    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-    if (body.size() != threadsPerWarp)
-      return false;
     return TypeSwitch<Type, bool>(body.front())
         .Case([this](FloatType floatTy) {
           // Support via bitcasting to integer type.
@@ -888,12 +898,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   }
 
   SmallVector<Value>
-  unwrapFromVector(Location loc, Value vec,
-                   ConversionPatternRewriter &rewriter) const {
+  unwrapFromVectors(Location loc, ArrayRef<Value> vecs,
+                    ConversionPatternRewriter &rewriter) const {
     SmallVector<Value> res;
-    for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
-         i < n; ++i)
-      res.push_back(extract_element(vec, i32_val(i)));
+    for (Value vec : vecs) {
+      for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
+           i < n; ++i)
+        res.push_back(extract_element(vec, i32_val(i)));
+    }
     return res;
   }
 
@@ -908,6 +920,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
     Type ptrType = smemBase.getType();
 
+    int numElements = inVals.size();
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
     int offset = threadsPerWarp;
     Type offsetType = getTypeConverter()->getIndexType();
@@ -922,7 +935,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     Value wiStride =
         rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
     Value sgStride = rewriter.create<LLVM::ConstantOp>(
-        loc, offsetType, threadsPerWarp * threadsPerWarp);
+        loc, offsetType, threadsPerWarp * numElements);
     Value subGroupOffset = mul(sgStride, subGroupId);
     Type elementType = opType.getElementType();
     Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
@@ -939,13 +952,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     }
 
     // Load from matrix, non-trasposed.
+    // As per SIMD block write semantics, the elements were stored as an
+    // `N x sub_group_size` matrix, so we need to load them back in blocks of
+    // `sub_group_size` (`N / sub_group_size` loads).
     Value workItemOffset = mul(wiStride, subGroupLocalId);
     Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
                                 ValueRange{workItemOffset}, /*inbounds=*/true);
-    Value transposedVec =
-        load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr);
-
-    return unwrapFromVector(loc, transposedVec, rewriter);
+    SmallVector<Value> transposedVecs;
+    Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp);
+    for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) {
+      transposedVecs.push_back(load(loadTy, workItemBasePtr));
+      workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
+                            ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
+    }
+    return unwrapFromVectors(loc, transposedVecs, rewriter);
   }
 
   LogicalResult