@@ -1486,7 +1486,11 @@ TEST_F(TransposeTest, NoTransposeMaverick17B) {
// Parameterized fixture for TMA transpose tests. The two bool params select
// scheduling variants (one appears to be output smem swizzle, per the
// INSTANTIATE name lambda below — TODO confirm the meaning of the first bool).
14861486class TransposeTMA
14871487 : public TransposeTest,
14881488 public testing::WithParamInterface<std::tuple<bool , bool >> {};
1489- TEST_P (TransposeTMA, TransposeTMALoadOptionalStore) {
1489+
1490+ // Transpose happens at input cached smem
1491+ // Each thread loads multiple columns from the input and writes them as
1492+ // multiple rows to the output.
1493+ TEST_P (TransposeTMA, TransposeInputSmem) {
14901494 NVFUSER_TEST_CUDA_ARCH_GUARD (9 , 0 );
14911495
14921496 Fusion fusion;
@@ -1666,4 +1670,155 @@ INSTANTIATE_TEST_SUITE_P(
16661670 std::string (
16671671 output_smem_swizzle ? " OutputSwizzle" : " NoOutputSwizzle" );
16681672 }));
1673+
1674+ // Transpose happens at output cached smem
1675+ // Each thread loads multiple rows from input and writes as multiple columns
1676+ // to output.
1677+ // with tma load, 852 ms on GB200.
1678+ // without tma load, 814 ms on GB200.
// NOTE(review): TEST_F is used on a fixture that derives WithParamInterface
// (TransposeTMA). GetParam() is never called here so it runs, but confirm
// gtest allows mixing TEST_F and TEST_P in the same test suite; otherwise
// this should use the plain TransposeTest fixture.
1679+ TEST_F (TransposeTMA, TransposeOutputSmem) {
1680+ NVFUSER_TEST_CUDA_ARCH_GUARD (9 , 0 );
// Toggle kept for the perf comparison documented above; TMA load path is
// off by default since the plain load was measured faster.
1681+ const bool use_tma_load = false ;
1682+
1683+ Fusion fusion;
1684+ FusionGuard fg (&fusion);
1685+
// Fusion: a 2D contiguous input transposed into the output.
1686+ auto input = makeContigTensor (2 );
1687+ fusion.addInput (input);
1688+ auto output = transpose (input, 0 , 1 );
1689+ fusion.addOutput (output);
1690+
1691+ TensorView* input_smem_cache = nullptr ;
1692+ TensorView* input_reg_cache = nullptr ;
1693+ if (use_tma_load) {
// gmem -> smem via TMA bulk load, then smem -> registers.
1694+ input_smem_cache =
1695+ input->cacheAfter (LoadStoreOpType::CpAsyncBulkTensorTile);
1696+ input_smem_cache->setMemoryType (MemoryType::Shared);
1697+ input_reg_cache = input_smem_cache->cacheAfter ();
1698+ } else {
// gmem -> registers directly.
1699+ input_reg_cache = input->cacheAfter ();
1700+ }
1701+ TensorView* output_smem_cache = nullptr ;
// Output is staged in smem and stored to gmem with a TMA bulk store.
1702+ output_smem_cache =
1703+ output->cacheBefore (LoadStoreOpType::CpAsyncBulkTensorTile);
1704+ output_smem_cache->setMemoryType (MemoryType::Shared);
// Adds one more cache level in front of the smem store; the returned tv is
// intentionally unused here — TODO confirm this extra stage is required.
1705+ output_smem_cache->cacheBefore ();
1706+
1707+ // gmem --> smem -> regs
1708+ // These tvs follow input layout [I0, I1].
1709+ TensorView* ref_tv = input_reg_cache;
1710+
1711+ // Swizzle parameters: 128-byte TMA swizzle with 16-byte chunks.
1712+ constexpr int64_t tma_swizzle_bytes = 128 ;
1713+ constexpr int64_t swizzle_chunk_bytes = 16 ;
1714+ const int64_t dtype_bytes =
1715+ dataTypeSizeByte (output_smem_cache->getDataType ().value ());
1716+ const int64_t elements_per_chunk = swizzle_chunk_bytes / dtype_bytes;
1717+ // tile_i0 must equal tma_swizzle_bytes / dtype_bytes (checked below).
1718+ const int64_t tile_i0 = 32 ;
1719+ NVF_ERROR_EQ (
1720+ tile_i0,
1721+ tma_swizzle_bytes / dtype_bytes,
1722+ " tile_i0 should span exactly 128 bytes" );
1723+ // tile_i1 and chunks_per_thread are tunable parameters.
1724+ const int64_t tile_i1 = 64 ;
1725+ const int64_t chunks_per_thread = 2 ;
1726+
1727+ // Step 1: Tile all tvs by tile_i0 (I0 dim) and tile_i1 (I1 dim).
1728+ // ref_tv has input layout [I0, I1].
1729+ // Target loop domain: [BIDx, tile_i0, tile_i1]
1730+ ref_tv->split (1 , tile_i1);
1731+ ref_tv->split (0 , tile_i0);
1732+ // [I0/tile_i0, tile_i0, I1/tile_i1, tile_i1]
1733+ ref_tv->reorder ({{-2 , 1 }});
1734+ ref_tv->merge (0 );
1735+ // [I0/tile_i0 * I1/tile_i1, tile_i0, tile_i1]
1736+ ref_tv->axis (0 )->parallelize (ParallelType::BIDx);
1737+ // [BIDx, tile_i0, tile_i1]
1738+
1739+ {
// Propagate the tiling to every tv in the fusion, then mirror ref_tv's
// parallelization onto them.
1740+ TransformPropagator propagator (ref_tv);
1741+ MaxLogicalDomainInfoSpanningTree entire_dag (ref_tv);
1742+ entire_dag.traverse (&propagator);
1743+ scheduler_utils::parallelizeAllLike (
1744+ ref_tv,
1745+ /* selected_tvs=*/ {},
1746+ /* selected_parallel_types=*/ {},
1747+ /* propagate_padding=*/ true ,
1748+ /* parallelize_inputs_on_did=*/ true );
1749+ }
1750+ // After propagation, all tvs have loop domain: [BIDx, tile_i0, tile_i1]
1751+ // ref_tv uses input layout (logical: [I0, I1]).
1752+ // input_smem_cache uses input layout (logical: [I0, I1]).
1753+
1754+ // Step 2: Schedule output TMA store (Bulk parallel on tile dims).
1755+ // output smem has output layout, [I1, I0]
1756+ // reorder to move inner most alloc dim to innermost position in loop
1757+ // domain
1758+ // [BIDx, tile_i0, tile_i1]
1759+ output_smem_cache->reorder ({{-1 , -2 }});
1760+ // [BIDx, tile_i1, tile_i0]
1761+ MmaInputSmemSwizzle swizzle =
1762+ mma_utils::tmaSwizzleSharedMemory (output_smem_cache);
1763+ mma_utils::scheduleTMAStoreForMmaOutput (output_smem_cache, swizzle);
1764+ mma_utils::scheduleTMAStoreForMmaOutput (output, swizzle);
1765+
1766+ // Step 3: Schedule input shared memory
1767+ // no swizzle, just contiguous load
1768+ // [BIDx, tile_i0, tile_i1]
1769+ if (use_tma_load) {
// Both tile dims are handled by a single TMA bulk load; allocation follows
// the loop domain so the load is contiguous.
1770+ input_smem_cache->axis (1 )->parallelize (ParallelType::Bulk);
1771+ input_smem_cache->axis (2 )->parallelize (ParallelType::Bulk);
1772+ input_smem_cache->setAllocationDomain (
1773+ input_smem_cache->getLoopDomain (), true );
1774+ }
1775+ // Step 4: Schedule register tvs for per-thread access pattern.
1776+ // [BIDx, tile_i0, tile_i1]
1777+ ref_tv->split (-2 , elements_per_chunk);
1778+ // [BIDx, tile_i0/chunk, chunk, tile_i1]
1779+ ref_tv->split (-3 , chunks_per_thread);
1780+ // [BIDx, tile_i0/chunk/cpt, cpt, chunk, tile_i1]
1781+ ref_tv->merge (-4 , -1 );
1782+ // [BIDx, tile_i0/chunk/cpt * tile_i1, cpt, chunk]
1783+ ref_tv->axis (-3 )->parallelize (ParallelType::TIDx);
1784+ // [BIDx, TIDx, cpt, chunk]
1785+ ref_tv->axis (-1 )->parallelize (ParallelType::Unroll);
1786+
1787+ {
1788+ // Propagate Step 4 transforms to all tvs except those already
1789+ // independently scheduled (input_smem_cache and TMA store output).
1790+ std::unordered_set<TensorView*> skip_tvs;
1791+ if (use_tma_load) {
1792+ skip_tvs.insert (input_smem_cache);
1793+ }
1794+ skip_tvs.insert (output);
1795+ auto propagate_tvs = ir_utils::allTvsExcept (&fusion, skip_tvs);
1796+ std::unordered_set<TensorView*> propagate_tvs_set (
1797+ propagate_tvs.begin (), propagate_tvs.end ());
1798+ SetSelector selector (propagate_tvs_set);
1799+ MaxLogicalDomainInfoSpanningTree propagate_dag (ref_tv, &selector);
1800+ TransformPropagator propagator (ref_tv);
1801+ propagate_dag.traverse (&propagator);
1802+ scheduler_utils::parallelizeAllLike (
1803+ ref_tv,
1804+ /* selected_tvs=*/ {propagate_tvs},
1805+ /* selected_parallel_types=*/ {},
1806+ /* propagate_padding=*/ true ,
1807+ /* parallelize_inputs_on_did=*/ true );
1808+ }
1809+
// Vectorize the innermost (chunk) axis of the smem write.
1810+ output_smem_cache->axis (-1 )->parallelize (ParallelType::Vectorize);
1811+
1812+ inlineMost ();
1813+
1814+ auto options = at::TensorOptions ().dtype (at::kFloat ).device (at::kCUDA , 0 );
1815+ at::Tensor input0 = at::randn ({16384 , 32768 }, options);
1816+
1817+ KernelExecutor ke;
1818+ ke.compile (&fusion, {input0});
1819+ auto outputs = ke.run ({input0});
1820+
// Validate the transpose result against the unscheduled fusion reference.
1821+ testValidate (&fusion, outputs, {input0}, __LINE__, __FILE__);
1822+ }
1823+
16691824} // namespace nvfuser
0 commit comments