Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cute/tutorial/xe_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ gemm_device(ATensor const& A, // (M,K)
/* Create block 2D TiledCopies */
auto copy_a = make_block_2d_copy_A(mma, A);
auto copy_b = make_block_2d_copy_B(mma, B);
auto copy_c = make_block_2d_copy_C(mma, C);
auto copy_c = make_block_2d_copy_D(mma, C);

/* Slice TiledCopy/TiledMMA operations to thread (work-item) level */
auto thr_mma = mma.get_slice(local_id);
Expand Down
56 changes: 40 additions & 16 deletions include/cute/atom/copy_traits_xe_2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,15 +911,25 @@ make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance
return make_block_2d_copy_C<ValType>(mma, gmem.stride()).with(gmem);
}

template <class TiledMMA, class CopyOp, class GEngine, class GLayout>
template <class TiledMMA, class GEngine, class GLayout>
CUTE_HOST_DEVICE
auto
make_block_2d_copy_C(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance
Tensor<GEngine, GLayout> const& gmem) // Global tensor
{
using ValType = typename GEngine::value_type;
return make_block_2d_copy_C<ValType>(op, mma, gmem.stride()).with(gmem);
return make_block_2d_copy_D<ValType>(mma, gmem.stride()).with(gmem);
}

template <class TiledMMA, class CopyOp, class GEngine, class GLayout>
CUTE_HOST_DEVICE
auto
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Tensor<GEngine, GLayout> const& gmem) // Global tensor
{
using ValType = typename GEngine::value_type;
return make_block_2d_copy_CD<ValType>(op, mma, gmem.stride()).with(gmem);
}

template <class ValType, class TiledMMA, class... Strides>
Expand All @@ -928,32 +938,46 @@ auto
make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance
Stride<Strides...> const& gstride) // Global memory strides
{
using MMAType = typename TiledMMA::ValTypeA;
using MMAType = typename TiledMMA::ValTypeC;
auto cC = make_identity_tensor(select<0,1>(mma.tile_mnk()));
auto op = block_2d_selector<ValType, MMAType, true>(
auto op = block_2d_selector<ValType, MMAType>(
mma.get_slice(0).atom_partition_C(cC).layout(), gstride
);
return make_block_2d_copy_C<ValType>(op, mma, gstride);
return make_block_2d_copy_CD<ValType>(op, mma, gstride);
}

template <class ValType, class TiledMMA, class CopyOp, class... Strides>
template <class ValType, class TiledMMA, class... Strides>
CUTE_HOST_DEVICE
auto
make_block_2d_copy_C(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance
Stride<Strides...> const& gstride) // Global memory strides
{
return make_block_2d_copy_C<ValType>(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride));
using MMAType = typename TiledMMA::ValTypeD;
auto cD = make_identity_tensor(select<0,1>(mma.tile_mnk()));
auto op = block_2d_selector<ValType, MMAType, true>(
mma.get_slice(0).atom_partition_C(cD).layout(), gstride
);
return make_block_2d_copy_CD<ValType>(op, mma, gstride);
}

template <class ValType, class TiledMMA, class CopyOp, class... Strides>
CUTE_HOST_DEVICE
auto
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Stride<Strides...> const& gstride) // Global memory strides
{
return make_block_2d_copy_CD<ValType>(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride));
}

template <class ValType, class TiledMMA, class CopyOp, class... Strides, class XMode, class YMode>
CUTE_HOST_DEVICE
auto
make_block_2d_copy_C(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Stride<Strides...> const& gstride, // Global memory strides
XMode const& x_mode, // x, y modes
YMode const& y_mode)
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Stride<Strides...> const& gstride, // Global memory strides
XMode const& x_mode, // x, y modes
YMode const& y_mode)
{
// Retrieve MMA atom's (subgroup, value) -> (M,N) layout
auto tile_mn = select<0,1>(mma.tile_mnk());
Expand Down
22 changes: 18 additions & 4 deletions media/docs/cpp/xe_rearchitecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ struct Copy_Traits</* Op */, XMode, YMode, ValType, TiledStrides>;

Since it can be a tricky to correctly choose block 2D parameters and set up an appropriate tiling, we introduce several helpers for creating TiledCopy objects.

The high-level APIs `make_block_2d_copy_{A,B,C}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically.
The high-level APIs `make_block_2d_copy_{A,B,C,D}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. Note that `make_block_2d_copy_C` and `make_block_2d_copy_D` only differ in their choice of a load (C) or store (D) operation.

```c++
template <class Engine, class Layout, /*...*/>
Expand All @@ -167,6 +167,12 @@ CUTE_DEVICE
TiledCopy<...>
make_block_2d_copy_C(const TiledMMA<...>&,
const Tensor<Engine, Layout>& gmem); // (M,N,...)

template <class Engine, class Layout, /*...*/>
CUTE_DEVICE
TiledCopy<...>
make_block_2d_copy_D(const TiledMMA<...>&,
const Tensor<Engine, Layout>& gmem); // (M,N,...)
```

The user may also override the choice of copy operation:
Expand All @@ -179,7 +185,15 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Tensor<GEngine, GLayout> const& gmem); // Global tensor

/* Similarly for B/C */
/* Similarly for B */

/* Single routine for both C/D */
template <class TiledMMA, class CopyOp, class GEngine, class GLayout>
CUTE_HOST_DEVICE
auto
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Tensor<GEngine, GLayout> const& gmem); // Global tensor
```

The `make_block_2d_copy_*` family of functions create TiledCopy objects that match the scope of the TiledMMA. That is, the set of threads participating in the TiledMMA will also participate in the TiledCopy.
Expand All @@ -194,7 +208,7 @@ TiledCopy
make_block_2d_copy(const CopyOp& op, const Tensor<Engine, Layout>& gmem);
```

For advanced usage, there are additional overloads of `make_block_2d_copy` that allow more general work distributions for copies (see `include/cute/atom/copy_traits_xe_2d.hpp`).
For advanced usage, there are additional overloads of `make_block_2d_copy` in which multiple subgroups participate (see `include/cute/atom/copy_traits_xe_2d.hpp`).

As the `CUTE_DEVICE` decorators imply, all the APIs above should be called from device code only, as they set up internal state that cannot be transferred from host to device.

Expand Down Expand Up @@ -419,7 +433,7 @@ gemm_device(ATensor const& A, // (M,K)
/* Create block 2D TiledCopies */
auto copy_a = make_block_2d_copy_A(mma, A);
auto copy_b = make_block_2d_copy_B(mma, B);
auto copy_c = make_block_2d_copy_C(mma, C);
auto copy_c = make_block_2d_copy_D(mma, C);

/* Slice TiledCopy/TiledMMA operations to thread (work-item) level */
auto thr_mma = mma.get_slice(local_id);
Expand Down
Loading