Skip to content
16 changes: 9 additions & 7 deletions csrc/device_lower/analysis/tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ std::unordered_set<const Expr*> getBatchableTmaLoads(
// We have some tests where TMA load is used in an unconventional way,
// e.g. parallelized with threads, or issued as a serial load, which requires
// multiple mbarriers or reuse of the same mbarrier.
if (std::any_of(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
[](const IterDomain* id) {
return id->isThreadDim() ||
id->getParallelType() == ParallelType::Serial;
})) {
auto non_trivial_ids =
tv->getLoopDomain() | std::views::filter([](const IterDomain* id) {
return !id->extent()->isConstScalar() ||
id->extent()->evaluate().as<int64_t>() > 1;
});
if (std::ranges::any_of(non_trivial_ids, [](const IterDomain* id) {
return id->isThreadDim() ||
id->getParallelType() == ParallelType::Serial;
})) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a trivial optimization for multiple TMA loads; it doesn't have to be in this PR.

return {};
}
non_cb_tma_load_exprs.push_back(expr);
Expand Down
3 changes: 2 additions & 1 deletion csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ auto parseEnvOptions(
available_options.end(),
std::back_inserter(option_values),
[](const auto& kv) { return kv.first; });
std::sort(option_values.begin(), option_values.end());
std::ranges::sort(option_values);
NVF_CHECK(
false,
"Parsing ",
Expand Down Expand Up @@ -174,6 +174,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"tma_pointwise", EnableOption::TmaPointwise},
{"tma_inner_persistent", EnableOption::TmaInnerPersistent},
{"tma_reduction", EnableOption::TmaReduction},
{"tma_transpose", EnableOption::TmaTranspose},
{"ws_normalization", EnableOption::WarpSpecializedNormalization},
{"host_ir_lowering", EnableOption::HostIrLowering},
{"host_ir_jit", EnableOption::HostIrJit},
Expand Down
19 changes: 11 additions & 8 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <visibility.h>

#include <algorithm>
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>
Expand All @@ -22,7 +23,7 @@ namespace nvfuser {
//!
//! These can be set through the `NVFUSER_DUMP` environment variable
//!
enum class DebugDumpOption {
enum class DebugDumpOption : std::uint8_t {
CutlassCompile, //!< Dump compile commands and compile times for
//!< CutlassExecutor
FunctionTrace, //!< Dump the function trace of selected internal function. The
Expand Down Expand Up @@ -97,7 +98,7 @@ enum class DebugDumpOption {
//!
//! These can be set through the `NVFUSER_ENABLE` environment variable
//!
enum class EnableOption {
enum class EnableOption : std::uint8_t {
CutlassScheduler, //! Enable the CUTLASS scheduler and executor
FuseMatmul, //! Enable automatic fusion of matmul and linear ops
FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel
Expand All @@ -118,6 +119,7 @@ enum class EnableOption {
TmaPointwise, //! Enable TMA pointwise kernel
TmaInnerPersistent, //! Enable TMA inner persistent kernel
TmaReduction, //! Enable TMA reduction kernel
TmaTranspose, //! Enable TMA transpose kernel
WarpSpecializedNormalization, //! Enable warp specialized persistent kernel
HostIrLowering, //! Enable FusionKernelRuntime lowering to host IR
HostIrJit, //! Enable Host IR JIT compilation with LLVM
Expand All @@ -134,7 +136,7 @@ enum class EnableOption {
//!
//! These can be set through the `NVFUSER_DISABLE` environment variable
//!
enum class DisableOption {
enum class DisableOption : std::uint8_t {
CompileToSass, //! Disable direct compilation to sass so the ptx can be
//! examined
ContigIndexing, //! Disable contiguous indexing
Expand Down Expand Up @@ -176,7 +178,7 @@ enum class DisableOption {
//!
//! These can be set through the `NVFUSER_PROF` environment variable
//!
enum class ProfilerOption {
enum class ProfilerOption : std::uint8_t {
Enable, //! Enables the profiler.
EnableNocupti, //! Enables the profiler, but disables CUPTI specific
//! profiling inorder to measure true host time without
Expand All @@ -197,10 +199,11 @@ class Options {
public:
Options() : options_(getOptionsFromEnv()) {}

Options(const Options& other) {
std::lock_guard<std::mutex> lock_other(other.mutex_);
options_ = other.options_;
}
Options(const Options& other)
: options_([&other]() {
std::lock_guard<std::mutex> lock_other(other.mutex_);
return other.options_;
}()) {}

Options& operator=(const Options& other) {
std::lock_guard<std::mutex> lock_other(other.mutex_);
Expand Down
10 changes: 6 additions & 4 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,11 @@ std::unique_ptr<HeuristicParams> TransposeScheduler::computeHeuristics(

std::unique_ptr<TransposeParams> tparams = nullptr;

// Try TMA path first
tparams =
transpose::tma::getTransposeHeuristics(fusion, runtime_info, data_cache);
// Try TMA path first if enabled
if (isOptionEnabled(EnableOption::TmaTranspose)) {
tparams = transpose::tma::getTransposeHeuristics(
fusion, runtime_info, data_cache);
}

// Fallback to non-TMA scheduler if TMA is not applicable
if (tparams == nullptr) {
Expand All @@ -431,7 +433,7 @@ void TransposeScheduler::schedule(
"Incorrect parameters sent to TransposeScheduler::schedule",
params);

if (tparams->use_tma_load) {
if (tparams->use_tma_load || tparams->use_tma_store) {
transpose::tma::scheduleTranspose(fusion, tparams);
} else {
transpose::non_tma::scheduleTranspose(fusion, tparams);
Expand Down
30 changes: 30 additions & 0 deletions csrc/scheduler/transpose_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ class TransposeParams : public HeuristicParams {

// Whether to use TMA for loading inputs
bool use_tma_load = false;
bool use_tma_store = false;

// Which side of shared memory holds the transposed (swizzled) layout.
// false: input smem is swizzled, transpose happens on smem->register read.
// true: output smem is swizzled, transpose happens on register->smem write.
// This is independent of use_tma_load/use_tma_store — TMA can be used for
// either side regardless of where the transpose swizzle lives.
bool is_output_smem_transpose = false;

// In a 128-byte swizzled TMA load, the innermost dim is split into 8 chunks
// of 16 bytes each. Each thread may handle multiple chunks along the
// innermost dim.
int64_t chunks_per_thread = 1;
int64_t elements_per_chunk = 1;

// Vectorization factor for tensors in the first group
int64_t vectorize_factor1 = 1;
Expand All @@ -65,6 +79,10 @@ class TransposeParams : public HeuristicParams {
}
bool attr_equal = other->cparams == cparams &&
other->use_tma_load == use_tma_load &&
other->use_tma_store == use_tma_store &&
other->is_output_smem_transpose == is_output_smem_transpose &&
other->chunks_per_thread == chunks_per_thread &&
other->elements_per_chunk == elements_per_chunk &&
other->split_before_tiling == split_before_tiling &&
other->dims_merged_with_1 == dims_merged_with_1 &&
other->dims_merged_with_2 == dims_merged_with_2 &&
Expand Down Expand Up @@ -99,6 +117,14 @@ class TransposeParams : public HeuristicParams {
if (unroll_factor2 > 1) {
ss << "Unroll group 2, Factor: " << unroll_factor2 << "\n";
}
if (use_tma_load || use_tma_store) {
ss << "TMA: load=" << (use_tma_load ? "true" : "false")
<< " store=" << (use_tma_store ? "true" : "false")
<< " is_output_smem_transpose="
<< (is_output_smem_transpose ? "true" : "false")
<< " chunks_per_thread=" << chunks_per_thread
<< " elements_per_chunk=" << elements_per_chunk << "\n";
}
if (!split_before_tiling.empty() || !dims_merged_with_1.empty() ||
!dims_merged_with_2.empty()) {
ss << "Virtual inner-most dim:\n";
Expand Down Expand Up @@ -146,6 +172,10 @@ class TransposeParams : public HeuristicParams {
size_t hash() const override {
return c10::get_hash(
use_tma_load,
use_tma_store,
is_output_smem_transpose,
chunks_per_thread,
elements_per_chunk,
split_before_tiling,
dims_merged_with_1,
dims_merged_with_2,
Expand Down
Loading
Loading