- 
                Notifications
    You must be signed in to change notification settings 
- Fork 64
Use newer version of copy_atom in epilogue collective #573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
578ae95
              aaf2685
              407a875
              45ac04e
              12282e8
              6596ac8
              7ea69d5
              6944d90
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -68,7 +68,7 @@ template < | |
| class CopyOpR2S_ | ||
| > | ||
| class CollectiveEpilogue< | ||
| IntelXeXMX16, | ||
| IntelXeL1Staged, | ||
| CtaTileMNK_, | ||
| ElementC_, | ||
| StrideC_, | ||
|  | @@ -86,7 +86,7 @@ class CollectiveEpilogue< | |
| // | ||
| // Type Aliases | ||
| // | ||
| using DispatchPolicy = IntelXeXMX16; | ||
| using DispatchPolicy = IntelXeL1Staged; | ||
| using CtaTileMNK = CtaTileMNK_; | ||
| using FusionCallbacks = FusionCallbacks_; | ||
| using ElementC = ElementC_; | ||
|  | @@ -102,9 +102,6 @@ class CollectiveEpilogue< | |
| using CopyOpR2S = CopyOpR2S_; | ||
|  | ||
| using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits<FusionCallbacks>::Operation; | ||
| using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_2D_U32x8x16_LD_N, CopyOpG2R>; | ||
| using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>, | ||
| CopyOpR2G, XE_2D_U32x8x16_ST_N>; | ||
| using ElementOutput = ElementD; | ||
| using ElementCompute = ElementAccumulator; | ||
|  | ||
|  | @@ -119,19 +116,10 @@ class CollectiveEpilogue< | |
| static_assert(std::is_same_v<SmemLayoutAtomC, void>, "Copy operation to shared memory is not supported"); | ||
| static_assert(std::is_same_v<SmemLayoutAtomD, void>, "Copy operation to shared memory is not supported"); | ||
|  | ||
| using CopyThreadShape = Shape<_1, Int<SubgroupSize>>; | ||
|  | ||
| using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>; | ||
| using val_layout_load_C = decltype(make_layout(shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{}))); | ||
| using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<Trait_C, ElementC>{}, Layout<CopyThreadShape>{}, val_layout_load_C{})); | ||
|  | ||
| using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>; | ||
| using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))); | ||
| using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{})); | ||
|  | ||
| //remember this PR https://github.com/intel/sycl-tla/pull/565/files | ||
| private: | ||
| constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>; | ||
| constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>; | ||
| constexpr static bool is_source_supported = not cute::is_void_v<ElementC>; | ||
| constexpr static bool is_destination_supported = not cute::is_void_v<ElementD>; | ||
|  | ||
| constexpr static bool is_m_major_C = detail::is_m_major<StrideC>(); | ||
| constexpr static bool is_m_major_D = detail::is_m_major<StrideD>(); | ||
|  | @@ -154,6 +142,15 @@ class CollectiveEpilogue< | |
| }; | ||
| using TensorStorage = typename SharedStorage::TensorStorage; | ||
|  | ||
| // Helper to get tensor types | ||
| template<class Element, class Stride> | ||
| using TensorTypeC = decltype(make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), | ||
| make_layout(make_shape(int{}, int{}, int{}), Stride{}))); | ||
|  | ||
| template<class Element, class Stride> | ||
| using TensorTypeD = decltype(make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), | ||
| make_layout(make_shape(int{}, int{}, int{}), Stride{}))); | ||
|  | ||
| // Host side epilogue arguments | ||
| struct Arguments { | ||
| typename FusionCallbacks::Arguments thread{}; | ||
|  | @@ -166,8 +163,8 @@ class CollectiveEpilogue< | |
| // Device side epilogue params | ||
| struct Params { | ||
| typename FusionCallbacks::Params thread{}; | ||
| XE_Copy_C xe_load_c; | ||
| XE_Copy_D xe_store_d; | ||
| TensorTypeC<ElementC, StrideC> mC; | ||
| TensorTypeD<ElementD, StrideD> mD; | ||
| }; | ||
|  | ||
| // | ||
|  | @@ -183,23 +180,13 @@ class CollectiveEpilogue< | |
| // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) | ||
| auto problem_shape_MNKL = append<4>(problem_shape, 1); | ||
| auto [M, N, K, L] = problem_shape_MNKL; | ||
|  | ||
| XE_Copy_C xe_load_c = {}; | ||
| if constexpr (is_source_supported) { | ||
| auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); | ||
| xe_load_c = {xe_load_c.with(mC)}; | ||
| } | ||
|  | ||
| XE_Copy_D xe_store_d = {}; | ||
| if constexpr (is_destination_supported) { | ||
| auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); | ||
| xe_store_d = {xe_store_d.with(mD)}; | ||
| } | ||
| auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC)); | ||
| auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD)); | ||
|  | ||
| return { | ||
| FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), | ||
| xe_load_c, | ||
| xe_store_d, | ||
| mC, | ||
| mD | ||
| }; | ||
| } | ||
|  | ||
|  | @@ -270,6 +257,24 @@ class CollectiveEpilogue< | |
| return fusion_callbacks.is_producer_load_needed(); | ||
| } | ||
|  | ||
| template<typename Tensor> | ||
| CUTLASS_DEVICE auto reshape_into_smaller_fragments(Tensor&& tensor) { | ||
|         
                  sanchitintel marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| using namespace cute; | ||
|  | ||
| auto target_stride = make_stride( | ||
| make_stride(cute::ScaledBasis<cute::Int<1>, 0>{}, _0{}), | ||
| cute::ScaledBasis<cute::Int<8>, 0>{}, | ||
| cute::ScaledBasis<cute::Int<16>, 1>{} | ||
| ); | ||
|  | ||
| auto target_layout = make_layout( | ||
| make_shape(make_shape(_8{}, _1{}), _4{}, _4{}), | ||
| target_stride | ||
|         
                  sanchitintel marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| ); | ||
|         
                  sanchitintel marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
|  | ||
| return make_tensor(tensor.data(), target_layout); | ||
| } | ||
|  | ||
| template< | ||
| class ProblemShapeMNKL, | ||
| class TileShapeMNK, | ||
|  | @@ -286,7 +291,6 @@ class CollectiveEpilogue< | |
| TiledMma tiled_mma, | ||
| int thread_idx) { | ||
|  | ||
| (void) tiled_mma; | ||
| using namespace cute; | ||
|  | ||
| static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); | ||
|  | @@ -297,12 +301,11 @@ class CollectiveEpilogue< | |
| static constexpr auto BLK_M = get<0>(CtaTileMNK{}); | ||
| static constexpr auto BLK_N = get<1>(CtaTileMNK{}); | ||
| static constexpr auto BLK_K = get<2>(CtaTileMNK{}); | ||
| // static_assert(is_same_v<typename TiledMma::ThrLayoutVMNK, int>, "assertation fail"); | ||
| static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); | ||
| static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); | ||
| static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); | ||
| static_assert( | ||
|  | ||
| static_assert( | ||
| BLK_M % ATOM_M == 0 && | ||
| BLK_N % ATOM_N == 0 && | ||
| BLK_K % ATOM_K == 0, | ||
|  | @@ -316,46 +319,49 @@ class CollectiveEpilogue< | |
| static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group | ||
|  | ||
| static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; | ||
|  | ||
| // Indexing variables | ||
| auto [M, N, K, L] = problem_shape_mnkl; | ||
| auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; | ||
| auto m_sg = get_sub_group_id() / ATOM_N; | ||
| auto n_sg = get_sub_group_id() % ATOM_N; | ||
|  | ||
| auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); | ||
|  | ||
| auto sg_local_m_coord = get_sub_group_id() / ATOM_N; | ||
| auto sg_local_n_coord = get_sub_group_id() % ATOM_N; | ||
|  | ||
| auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; | ||
| auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; | ||
| auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); | ||
|  | ||
|  | ||
| auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord); | ||
| bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); | ||
|  | ||
| auto batch_idx = get<3>(wg_coord); | ||
| auto copy_c = get_block_2d_copy_C<CopyOpG2R>(tiled_mma, params.mC(_,_,batch_idx)); | ||
| auto copy_d = get_block_2d_copy_D<CopyOpR2G>(tiled_mma, params.mD(_,_,batch_idx)); | ||
|  | ||
|  | ||
|  | ||
| // Represent the full output tensor | ||
| Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L)); | ||
|  | ||
| // Tile the output tensor per WG and select the tile for current WG | ||
| Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N) | ||
|  | ||
| // Tile the output tensor per SG and select tile for the current SG | ||
| Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N) | ||
| // Tile the output tensor for the current workgroup | ||
| Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N) | ||
|  | ||
| auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx); | ||
| // Get thread-level partitioning across the entire workgroup tile | ||
| auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx); | ||
| Tensor tCgC = thread_xe_load_c.partition_S(gD); | ||
|          | ||
|  | ||
| auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); | ||
| auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx); | ||
| Tensor tCgD = thread_xe_store_d.partition_D(gD); | ||
|  | ||
| auto tCgC_frag = reshape_into_smaller_fragments(tCgC); | ||
| auto tCgD_frag = reshape_into_smaller_fragments(tCgD); | ||
|  | ||
| Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{}); | ||
| Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{}); | ||
|  | ||
| // Because Sm90 uses shared memory, they are not tied to using the same accumulator values | ||
| // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be | ||
| // sure that we are operating on the same values. | ||
| ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); | ||
| ThrCopy thread_g2r = copy_c.get_slice(thread_idx); | ||
| auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{}); | ||
|  | ||
| // OOB predication for tile quantization "residue" | ||
| // Absolute coordinate tensors (dynamic) | ||
|  | @@ -364,7 +370,7 @@ class CollectiveEpilogue< | |
| Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) | ||
| Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) | ||
|  | ||
| Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) | ||
| Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); | ||
|  | ||
| // Get the fusion callbacks | ||
| // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles | ||
|  | @@ -376,7 +382,7 @@ class CollectiveEpilogue< | |
| sg_coord, | ||
| tiled_mma, | ||
| mn_shape, | ||
| params.xe_store_d, | ||
| copy_d, | ||
| cD, | ||
| residue_mn, | ||
| tRS_cD, | ||
|  | @@ -398,7 +404,8 @@ class CollectiveEpilogue< | |
| FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K; | ||
| constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); | ||
| static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); | ||
|  | ||
|  | ||
|  | ||
| auto synchronize = [&] () {}; | ||
| CUTLASS_PRAGMA_UNROLL | ||
| for (int epi_n = 0; epi_n < FragsN; epi_n++) { | ||
|  | @@ -407,7 +414,7 @@ class CollectiveEpilogue< | |
| cst_callbacks.begin_loop(epi_m, epi_n); | ||
|  | ||
| if (is_C_load_needed) { | ||
| copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC); | ||
| copy(copy_c, tCgC_frag(_, epi_m, epi_n), trC); | ||
| } | ||
|  | ||
| cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); | ||
|  | @@ -419,21 +426,23 @@ class CollectiveEpilogue< | |
| trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); | ||
| } | ||
| cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag); | ||
|  | ||
| if constexpr (is_destination_supported) { | ||
| CUTLASS_PRAGMA_UNROLL | ||
| for (int i = 0; i < size(trD_compute_frag); ++i) { | ||
| trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i)); | ||
| } | ||
| copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); | ||
| copy(copy_d, trD, tCgD_frag(_, epi_m, epi_n)); | ||
| } | ||
|  | ||
| cst_callbacks.end_loop(epi_m, epi_n); | ||
|  | ||
| } | ||
| } | ||
|  | ||
| cst_callbacks.end(); | ||
| } | ||
|  | ||
| } | ||
|  | ||
| private: | ||
| Params const& params; | ||
|  | @@ -447,4 +456,4 @@ class CollectiveEpilogue< | |
| } // namespace epilogue | ||
| } // namespace cutlass | ||
|  | ||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | ||
| ///////////////////////////////////////////////////////////////////////////////////////////////// | ||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jiyang1011 - The validation logic from PR #565 that sets is_source_supported to false when CopyOpG2R is void needs updating. With this PR's automatic ops selection, both CopyOpG2R and CopyOpR2G can now legitimately be void since make_block_2d_copy_* automatically selects appropriate operations.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we set a default copy trait like XeCopyAuto or something else which will also call make_block_2d_copy_* ?