diff --git a/examples/cute/tutorial/xe_gemm.cpp b/examples/cute/tutorial/xe_gemm.cpp index 887c23599d..ce82b6be88 100644 --- a/examples/cute/tutorial/xe_gemm.cpp +++ b/examples/cute/tutorial/xe_gemm.cpp @@ -86,7 +86,7 @@ gemm_device(ATensor const& A, // (M,K) /* Create block 2D TiledCopies */ auto copy_a = make_block_2d_copy_A(mma, A); auto copy_b = make_block_2d_copy_B(mma, B); - auto copy_c = make_block_2d_copy_C(mma, C); + auto copy_c = make_block_2d_copy_D(mma, C); /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ auto thr_mma = mma.get_slice(local_id); diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp index 2df8ae0a38..adb124673a 100644 --- a/include/cute/atom/copy_traits_xe_2d.hpp +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -911,15 +911,25 @@ make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance return make_block_2d_copy_C(mma, gmem.stride()).with(gmem); } -template +template CUTE_HOST_DEVICE auto -make_block_2d_copy_C(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance +make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance Tensor const& gmem) // Global tensor { using ValType = typename GEngine::value_type; - return make_block_2d_copy_C(op, mma, gmem.stride()).with(gmem); + return make_block_2d_copy_D(mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_CD(op, mma, gmem.stride()).with(gmem); } template @@ -928,32 +938,46 @@ auto make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance Stride const& gstride) // Global memory strides { - using MMAType = typename TiledMMA::ValTypeA; + using MMAType = typename TiledMMA::ValTypeC; auto cC = make_identity_tensor(select<0,1>(mma.tile_mnk())); - auto op = block_2d_selector( + auto op = block_2d_selector( mma.get_slice(0).atom_partition_C(cC).layout(), gstride ); - return make_block_2d_copy_C(op, mma, gstride); + return make_block_2d_copy_CD(op, mma, gstride); } -template +template CUTE_HOST_DEVICE auto -make_block_2d_copy_C(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance +make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance Stride const& gstride) // Global memory strides { - return make_block_2d_copy_C(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); + using MMAType = typename TiledMMA::ValTypeD; + auto cD = make_identity_tensor(select<0,1>(mma.tile_mnk())); + auto op = block_2d_selector( + mma.get_slice(0).atom_partition_C(cD).layout(), gstride + ); + return make_block_2d_copy_CD(op, mma, gstride); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_CD(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); } template CUTE_HOST_DEVICE auto -make_block_2d_copy_C(CopyOp const& op, // Copy operation - TiledMMA const& mma, // TiledMMA instance - Stride const& gstride, // Global memory strides - XMode const& x_mode, // x, y modes - YMode const& y_mode) +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) { // Retrieve MMA atom's (subgroup, value) -> (M,N) layout auto tile_mn = select<0,1>(mma.tile_mnk()); diff --git a/media/docs/cpp/xe_rearchitecture.md b/media/docs/cpp/xe_rearchitecture.md index 7e5c49bfb4..45021aba51 100644 --- a/media/docs/cpp/xe_rearchitecture.md +++ b/media/docs/cpp/xe_rearchitecture.md @@ -147,7 +147,7 @@ struct Copy_Traits; Since it can be a tricky to correctly choose block 2D parameters and set up an appropriate tiling, we introduce several helpers for creating TiledCopy objects. -The high-level APIs `make_block_2d_copy_{A,B,C}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. +The high-level APIs `make_block_2d_copy_{A,B,C,D}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. Note that `make_block_2d_copy_C` and `make_block_2d_copy_D` only differ in their choice of a load (C) or store (D) operation. ```c++ template @@ -167,6 +167,12 @@ CUTE_DEVICE TiledCopy<...> make_block_2d_copy_C(const TiledMMA<...>&, const Tensor& gmem); // (M,N,...) + +template +CUTE_DEVICE +TiledCopy<...> +make_block_2d_copy_D(const TiledMMA<...>&, + const Tensor& gmem); // (M,N,...) ``` The user may also override the choice of copy operation: @@ -179,7 +185,15 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation TiledMMA const& mma, // TiledMMA instance Tensor const& gmem); // Global tensor -/* Similarly for B/C */ +/* Similarly for B */ + +/* Single routine for both C/D */ +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_CD(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem); // Global tensor ``` The `make_block_2d_copy_*` family of functions create TiledCopy objects that match the scope of the TiledMMA. That is, the set of threads participating in the TiledMMA will also participate in the TiledCopy. @@ -194,7 +208,7 @@ TiledCopy make_block_2d_copy(const CopyOp& op, const Tensor& gmem); ``` -For advanced usage, there are additional overloads of `make_block_2d_copy` that allow more general work distributions for copies (see `include/cute/atom/copy_traits_xe_2d.hpp`). +For advanced usage, there are additional overloads of `make_block_2d_copy` in which multiple subgroups participate (see `include/cute/atom/copy_traits_xe_2d.hpp`). As the `CUTE_DEVICE` decorators imply, all the APIs above should be called from device code only, as they set up internal state that cannot be transferred from host to device. @@ -419,7 +433,7 @@ gemm_device(ATensor const& A, // (M,K) /* Create block 2D TiledCopies */ auto copy_a = make_block_2d_copy_A(mma, A); auto copy_b = make_block_2d_copy_B(mma, B); - auto copy_c = make_block_2d_copy_C(mma, C); + auto copy_c = make_block_2d_copy_D(mma, C); /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ auto thr_mma = mma.get_slice(local_id);