Skip to content

Commit 3ac302f

Browse files
authored
manual schedule of a transpose in output cached smem (#6008)
This PR adds a manual scheduling test case demonstrating how to perform a transpose on cached output shared memory using a TMA store. The transpose scheduler may choose to apply the transpose on either the cached input or the cached output, depending on the number of inputs and outputs. The guiding principle is to minimize the total number of required transposes; e.g., it applies the transpose on the cached output when there are more inputs than outputs.
1 parent 0cfc24b commit 3ac302f

File tree

1 file changed

+156
-1
lines changed

1 file changed

+156
-1
lines changed

tests/cpp/test_transpose.cpp

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,11 @@ TEST_F(TransposeTest, NoTransposeMaverick17B) {
14861486
// Fixture for TMA-based transpose tests. The tuple parameters are two bools;
// one is output_smem_swizzle (see the INSTANTIATE_TEST_SUITE_P name generator
// below) — the meaning of the other is not visible here; TODO confirm.
class TransposeTMA
    : public TransposeTest,
      public testing::WithParamInterface<std::tuple<bool, bool>> {};
1489-
TEST_P(TransposeTMA, TransposeTMALoadOptionalStore) {
1489+
1490+
// Transpose happens at input cached smem
1491+
// Each thread loads multiple columns from input and write as multiple rows to
1492+
// output.
1493+
TEST_P(TransposeTMA, TransposeInputSmem) {
14901494
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
14911495

14921496
Fusion fusion;
@@ -1666,4 +1670,155 @@ INSTANTIATE_TEST_SUITE_P(
16661670
std::string(
16671671
output_smem_swizzle ? "OutputSwizzle" : "NoOutputSwizzle");
16681672
}));
1673+
1674+
// Manual schedule where the transpose happens at the output's cached smem
// buffer: each thread loads multiple rows from the input and writes them as
// multiple columns to the output, which is then stored to gmem via TMA.
// Perf notes from the author (GB200):
//   with tma load, 852 ms; without tma load, 814 ms.
TEST_F(TransposeTMA, TransposeOutputSmem) {
  // TMA (CpAsyncBulkTensorTile) requires Hopper (sm90) or newer.
  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
  // Toggle between a TMA gmem->smem->regs input path and a plain
  // gmem->regs path. Hard-coded off (the non-TMA load measured faster above).
  const bool use_tma_load = false;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // 2D contiguous input; output is its transpose.
  auto input = makeContigTensor(2);
  fusion.addInput(input);
  auto output = transpose(input, 0, 1);
  fusion.addOutput(output);

  // Input-side caches: optionally stage through shared memory with a TMA
  // load, always ending in a register cache that the threads read from.
  TensorView* input_smem_cache = nullptr;
  TensorView* input_reg_cache = nullptr;
  if (use_tma_load) {
    input_smem_cache =
        input->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
    input_smem_cache->setMemoryType(MemoryType::Shared);
    input_reg_cache = input_smem_cache->cacheAfter();
  } else {
    input_reg_cache = input->cacheAfter();
  }
  // Output-side caches: regs -> smem -> gmem, with the smem->gmem copy done
  // as a TMA store. The extra cacheBefore() inserts the register stage.
  TensorView* output_smem_cache = nullptr;
  output_smem_cache =
      output->cacheBefore(LoadStoreOpType::CpAsyncBulkTensorTile);
  output_smem_cache->setMemoryType(MemoryType::Shared);
  output_smem_cache->cacheBefore();

  // gmem --> smem -> regs
  // These tvs follow input layout [I0, I1].
  TensorView* ref_tv = input_reg_cache;

  // Swizzle parameters: 128-byte TMA swizzle with 16-byte chunks.
  constexpr int64_t tma_swizzle_bytes = 128;
  constexpr int64_t swizzle_chunk_bytes = 16;
  const int64_t dtype_bytes =
      dataTypeSizeByte(output_smem_cache->getDataType().value());
  const int64_t elements_per_chunk = swizzle_chunk_bytes / dtype_bytes;
  // tile_i0 must equal tma_swizzle_bytes / dtype_bytes (checked below).
  const int64_t tile_i0 = 32;
  NVF_ERROR_EQ(
      tile_i0,
      tma_swizzle_bytes / dtype_bytes,
      "tile_i0 should span exactly 128 bytes");
  // tile_i1 and chunks_per_thread are tunable parameters.
  const int64_t tile_i1 = 64;
  const int64_t chunks_per_thread = 2;

  // Step 1: Tile all tvs by tile_i0 (I0 dim) and tile_i1 (I1 dim).
  // ref_tv has input layout [I0, I1].
  // Target loop domain: [BIDx, tile_i0, tile_i1]
  ref_tv->split(1, tile_i1);
  ref_tv->split(0, tile_i0);
  // [I0/tile_i0, tile_i0, I1/tile_i1, tile_i1]
  ref_tv->reorder({{-2, 1}});
  ref_tv->merge(0);
  // [I0/tile_i0 * I1/tile_i1, tile_i0, tile_i1]
  ref_tv->axis(0)->parallelize(ParallelType::BIDx);
  // [BIDx, tile_i0, tile_i1]

  {
    // Propagate the tiling to every tv in the fusion.
    TransformPropagator propagator(ref_tv);
    MaxLogicalDomainInfoSpanningTree entire_dag(ref_tv);
    entire_dag.traverse(&propagator);
    scheduler_utils::parallelizeAllLike(
        ref_tv,
        /*selected_tvs=*/{},
        /*selected_parallel_types=*/{},
        /*propagate_padding=*/true,
        /*parallelize_inputs_on_did=*/true);
  }
  // After propagation, all tvs have loop domain: [BIDx, tile_i0, tile_i1]
  // ref_tv uses input layout (logical: [I0, I1]).
  // input_smem_cache uses input layout (logical: [I0, I1]).

  // Step 2: Schedule output TMA store (Bulk parallel on tile dims).
  // output smem has output layout, [I1, I0]
  // reorder to move inner most alloc dim to innermost position in loop
  // domain
  // [BIDx, tile_i0, tile_i1]
  output_smem_cache->reorder({{-1, -2}});
  // [BIDx, tile_i1, tile_i0]
  MmaInputSmemSwizzle swizzle =
      mma_utils::tmaSwizzleSharedMemory(output_smem_cache);
  mma_utils::scheduleTMAStoreForMmaOutput(output_smem_cache, swizzle);
  mma_utils::scheduleTMAStoreForMmaOutput(output, swizzle);

  // Step 3: Schedule input shared memory
  // no swizzle, just contiguous load
  // [BIDx, tile_i0, tile_i1]
  if (use_tma_load) {
    input_smem_cache->axis(1)->parallelize(ParallelType::Bulk);
    input_smem_cache->axis(2)->parallelize(ParallelType::Bulk);
    input_smem_cache->setAllocationDomain(
        input_smem_cache->getLoopDomain(), true);
  }
  // Step 4: Schedule register tvs for per-thread access pattern.
  // [BIDx, tile_i0, tile_i1]
  ref_tv->split(-2, elements_per_chunk);
  // [BIDx, tile_i0/chunk, chunk, tile_i1]
  ref_tv->split(-3, chunks_per_thread);
  // [BIDx, tile_i0/chunk/cpt, cpt, chunk, tile_i1]
  ref_tv->merge(-4, -1);
  // [BIDx, tile_i0/chunk/cpt * tile_i1, cpt, chunk]
  ref_tv->axis(-3)->parallelize(ParallelType::TIDx);
  // [BIDx, TIDx, cpt, chunk]
  ref_tv->axis(-1)->parallelize(ParallelType::Unroll);

  {
    // Propagate Step 4 transforms to all tvs except those already
    // independently scheduled (input_smem_cache and TMA store output).
    std::unordered_set<TensorView*> skip_tvs;
    if (use_tma_load) {
      skip_tvs.insert(input_smem_cache);
    }
    skip_tvs.insert(output);
    auto propagate_tvs = ir_utils::allTvsExcept(&fusion, skip_tvs);
    std::unordered_set<TensorView*> propagate_tvs_set(
        propagate_tvs.begin(), propagate_tvs.end());
    SetSelector selector(propagate_tvs_set);
    MaxLogicalDomainInfoSpanningTree propagate_dag(ref_tv, &selector);
    TransformPropagator propagator(ref_tv);
    propagate_dag.traverse(&propagator);
    scheduler_utils::parallelizeAllLike(
        ref_tv,
        /*selected_tvs=*/{propagate_tvs},
        /*selected_parallel_types=*/{},
        /*propagate_padding=*/true,
        /*parallelize_inputs_on_did=*/true);
  }

  // Vectorize the innermost (chunk) dim of the smem write.
  output_smem_cache->axis(-1)->parallelize(ParallelType::Vectorize);

  inlineMost();

  // Run on a large fp32 problem and validate against the reference.
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::randn({16384, 32768}, options);

  KernelExecutor ke;
  ke.compile(&fusion, {input0});
  auto outputs = ke.run({input0});

  testValidate(&fusion, outputs, {input0}, __LINE__, __FILE__);
}
1823+
16691824
} // namespace nvfuser

0 commit comments

Comments
 (0)