Commits (27) — all authored by Micky774:

d10fa92  Initial commit (Oct 24, 2025)
eef7dc0  Updated to build from source by default (Oct 24, 2025)
cc68ab7  Updated for V3 API (Oct 31, 2025)
4455361  Fixed build, reverted AOTriton bwd changes (now V2) (Nov 3, 2025)
2586b18  Removed alterations (Nov 3, 2025)
aa80f81  Removed lazy tensor wrapper (Nov 3, 2025)
9a91b9e  Streamlined CMakeLists, other PR review feedback addressed (Nov 4, 2025)
023deb4  Removed `pad_between_seqs` (Nov 4, 2025)
6b8dbe5  Updated typing to be more explicit (Nov 4, 2025)
68303d0  Minor streamlining and formatting (Nov 4, 2025)
8181972  Initial implementation (Nov 6, 2025)
6788a16  Simplified window size func for current non-SWA support (Nov 6, 2025)
182101a  Removed accidental include (Nov 6, 2025)
19a9c0f  Merge branch 'zain/aotriton' into zain/aotriton-bwd (Nov 6, 2025)
fef6baa  Corrected bwd args (Nov 6, 2025)
3a4fab8  Updated causal window default (Nov 10, 2025)
917e3c3  Updated window values for causal (Nov 10, 2025)
ce32e3b  Merge branch 'zain/aotriton' into zain/aotriton-bwd (Nov 10, 2025)
36045c8  Corrected DQ_ACC buffer, added env var for GPU kernel building (Nov 12, 2025)
d6e46c1  Update AOTriton to 0.11.1b (Nov 12, 2025)
1349a48  Merge branch 'dev' into zain/aotriton (Nov 24, 2025)
8ed0009  Merge branch 'zain/aotriton' into zain/aotriton-bwd (Nov 24, 2025)
2bd9006  Added AOTriton commit SHA (Nov 25, 2025)
a9bef37  Merge branch 'dev' into zain/aotriton-bwd (Nov 25, 2025)
0fdff86  Moved handling of env variable to makefile (Nov 26, 2025)
3f6e054  Simplified lazy tensor implementation (Dec 1, 2025)
2246da4  Merge branch 'dev' into zain/aotriton-bwd (Dec 10, 2025)
1 change: 0 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -8,7 +8,6 @@ cmake_minimum_required(VERSION 3.21)

option(USE_ROCM "Use ROCm" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
set(USE_CUDA OFF)

3 changes: 3 additions & 0 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -492,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
fused_attn_aotriton_bwd_qkvpacked(
b, h, max_seqlen, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO, output_S,
output_dQKV,
@@ -678,6 +679,7 @@
fused_attn_aotriton_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO,
output_S,
@@ -858,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
output_S,
158 changes: 127 additions & 31 deletions transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
@@ -296,32 +296,52 @@ void fused_attn_aotriton_fwd_impl(
NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream));
}

// A thin conversion wrapper around eager tensor-views to lazy tensors
template<int kRank>
struct LazyTensorFunctions {
static aotriton::TensorView<kRank> acquire(void* cookie) {
return *static_cast<aotriton::TensorView<kRank>*>(cookie);
}
static void dispose(void* cookie) {
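// No-op: the cookie points at a caller-owned TensorView, so there is nothing to release.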
}
};

void fused_attn_aotriton_bwd_impl(
uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d,
float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout,
int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrO, void* devPtrSoftmaxAux,
void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrdO,
void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
const uint64_t* devPtrDropoutSeed,
const uint64_t* devPtrDropoutOffset,
aotriton::DType dtype,
void *workspace,
size_t *workspace_size,
cudaStream_t stream) {

const uint64_t dq_acc_size = b*s_q*h*d*sizeof(float);

// Exit to request upper level API to allocate memory if needed
if(workspace==nullptr){
// CK only requires workspace for lse softmax
// AOTriton requires workspace for lse softmax
*workspace_size = b*h*s_q*sizeof(float);
// AOTriton requires workspace for DQ_ACC
*workspace_size += dq_acc_size;
return;
}
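// Workspace layout: the first b*h*s_q floats back the delta (D) tensor; the
// remaining b*s_q*h*d floats back the fp32 DQ_ACC accumulator.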
void * delta = workspace;
workspace = static_cast<void *>(static_cast<int8_t *>(workspace) + b*h*s_q*sizeof(float));
void * dq_acc_ptr = workspace;

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
std::array<uint64_t, 4> o_stride;
std::array<uint64_t, 4> dq_acc_stride;
generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(),
@@ -330,14 +350,17 @@ void fused_attn_aotriton_bwd_impl(
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
// AOTriton expects a BSHD layout DQ_ACC matrix
generateMatrixStrides(b, h, s_q, s_kv, d, dq_acc_stride.data(),
NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_QKV_Matrix::NVTE_Q_Matrix);

// q and o have the same shape
// k and v have the same shape
// x and dx have the same shape and stride
std::array<uint64_t, 4> q_shape{b, h, s_q, d};
std::array<uint64_t, 4> kv_shape{b, hg, s_kv, d};

// m and workspace are of the same shape and stride
// m and softmax_lse are of the same shape and stride
std::array<uint64_t, 2> m_shape{b * h, s_q};
std::array<uint64_t, 2> m_stride{s_q, 1};

@@ -355,52 +378,116 @@

// auxiliary tensors
auto M_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32);
auto wkspace_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(workspace), m_shape, m_stride, aotriton::DType::kFloat32);
auto delta_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(delta), m_shape, m_stride, aotriton::DType::kFloat32);
auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast<intptr_t>(dq_acc_ptr), q_shape, dq_acc_stride, aotriton::DType::kFloat32);
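// DQ_ACC is an accumulation buffer, so it must start zeroed before the kernel launches.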
NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream));

auto dq_acc_lazy = aotriton::LazyTensor<4> {
.cookie = &dq_acc_tensor,
.acquire = &LazyTensorFunctions<4>::acquire,
.dispose = &LazyTensorFunctions<4>::dispose
};
auto delta_lazy = aotriton::LazyTensor<2> {
.cookie = &delta_tensor,
.acquire = &LazyTensorFunctions<2>::acquire,
.dispose = &LazyTensorFunctions<2>::dispose
};
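// The V3 entry point takes these as lazy tensors; here acquire() simply returns the pre-built eager views.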

// Cumulative seqlen tensors
std::array<uint64_t, 1> cu_seqlens_shape{b+1};
std::array<uint64_t, 1> cu_seqlens_stride{1};
auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
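// These are only forwarded to the kernel when varlen_type != VarlenType::None (see bwd_params below).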

bool nvte_log_aotriton_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_aotriton_config = true;
}
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
using aotriton::v2::flash::attn_bwd;
auto seed = mk_aoscalartensor(devPtrDropoutSeed);
auto offset = mk_aoscalartensor(devPtrDropoutOffset);
const auto is_causal = mask_type == NVTE_CAUSAL_MASK;

using aotriton::v3::flash::VarlenType;
int8_t varlen_type = VarlenType::None;

auto [window_left, window_right] = get_window_sizes(window_size_left, window_size_right, is_causal);
using aotriton::v3::flash::CausalType;
int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
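// Causal masking is routed through the windowed-attention path; get_window_sizes() picks the (left, right) window for the causal and non-causal cases.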

if (nvte_log_aotriton_config) {
std::cout<<std::endl<<"attn_bwd(aotriton): ";
std::cout<<"q_shape: ("<<b<<", "<<h<<", "<<s_q<<", "<<d<<"), ";
std::cout<<"q_stride: ("<<q_stride[0]<<", "<<q_stride[1]<<", "<<q_stride[2]<<", "<<q_stride[3]<<"), ";
std::cout<<"kv_shape: ("<<b<<", "<<hg<<", "<<s_kv<<", "<<d<<"), ";
std::cout<<"k_stride: ("<<k_stride[0]<<", "<<k_stride[1]<<", "<<k_stride[2]<<", "<<k_stride[3]<<"), ";
std::cout<<"v_stride: ("<<v_stride[0]<<", "<<v_stride[1]<<", "<<v_stride[2]<<", "<<v_stride[3]<<"), ";
std::cout<<"scaling_factor: "<<scaling_factor<<", ";
std::cout<<"M_shape: ("<<b*h<<", "<<s_q<<"), ";
std::cout<<"M_stride: ("<<s_q<<", "<<1<<"), ";
std::cout<<"o_shape: ("<<b<<", "<<h<<", "<<s_q<<", "<<d<<"), ";
std::cout<<"o_stride: ("<<o_stride[0]<<", "<<o_stride[1]<<", "<<o_stride[2]<<", "<<o_stride[3]<<"), ";
std::cout<<"dropout_p: "<<dropout_probability<<", ";
std::cout<<"causal mask: "<<(mask_type==NVTE_CAUSAL_MASK)<<std::endl;

std::cout<< "\nAOTriton attn_bwd_params:\n";
std::cout<<"Q: "<<q_tensor.data_ptr()<<"\n";
std::cout<<"K: "<<k_tensor.data_ptr()<<"\n";
std::cout<<"V: "<<v_tensor.data_ptr()<<"\n";
std::cout<<"B: "<<empty_bias.data_ptr()<<"\n";
std::cout<<"Sm_scale: "<<scaling_factor<<"\n";
std::cout<<"Out: "<<o_tensor.data_ptr()<<"\n";
std::cout<<"cu_seqlens_q: "<<cu_seqlens_q.data_ptr()<<"\n";
std::cout<<"cu_seqlens_k: "<<cu_seqlens_k.data_ptr()<<"\n";
std::cout<<"Max_seqlen_q: "<<s_q<<"\n";
std::cout<<"Max_seqlen_k: "<<s_kv<<"\n";
std::cout<<"DO: "<<do_tensor.data_ptr()<<"\n";
std::cout<<"DK: "<<dk_tensor.data_ptr()<<"\n";
std::cout<<"DV: "<<dv_tensor.data_ptr()<<"\n";
std::cout<<"DQ: "<<dq_tensor.data_ptr()<<"\n";
std::cout<<"DB: "<<empty_bias.data_ptr()<<"\n";
std::cout<<"L: "<<M_tensor.data_ptr()<<"\n";
std::cout<<"D: "<<delta_tensor.data_ptr()<<"\n";
std::cout<<"dropout_p: "<<dropout_probability<<"\n";
std::cout<<"philox_seed_ptr: "<<seed.data_ptr()<<"\n";
std::cout<<"philox_offset1: "<<offset.data_ptr()<<"\n";
std::cout<<"philox_offset2: "<<0<<"\n";
std::cout<<"causal_type: "<<+causal_type<<"\n";
std::cout<<"varlen_type: "<<+varlen_type<<"\n";
std::cout<<"window_left: "<<window_left<<"\n";
std::cout<<"window_right: "<<window_right<<"\n";
std::cout<<"DQ_ACC: "<<dq_acc_tensor.data_ptr()<<"\n";
}
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
using aotriton::v2::flash::attn_bwd;
auto seed = mk_aoscalartensor(devPtrDropoutSeed);
auto offset = mk_aoscalartensor(devPtrDropoutOffset);
const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
NVTE_CHECK_CUDA(attn_bwd(q_tensor,
k_tensor,
v_tensor,
empty_bias,
scaling_factor,
o_tensor,
do_tensor,
dq_tensor,
dk_tensor,
dv_tensor,
empty_bias,
M_tensor,
wkspace_tensor,
dropout_probability,
seed,
offset,
0,
is_causal,
stream));
aotriton::v3::flash::attn_bwd_params bwd_params{};
bwd_params.Q = q_tensor;
bwd_params.K = k_tensor;
bwd_params.V = v_tensor;
bwd_params.B = empty_bias;
bwd_params.Sm_scale = scaling_factor;
bwd_params.Out = o_tensor;
if(varlen_type){
bwd_params.cu_seqlens_q = cu_seqlens_q;
bwd_params.cu_seqlens_k = cu_seqlens_k;
bwd_params.Max_seqlen_q = s_q;
bwd_params.Max_seqlen_k = s_kv;
}
bwd_params.DO = do_tensor;
bwd_params.DK = dk_tensor;
bwd_params.DV = dv_tensor;
bwd_params.DQ = dq_tensor;
bwd_params.DB = empty_bias;
bwd_params.L = M_tensor;
bwd_params.D = delta_lazy;
bwd_params.dropout_p = dropout_probability;
bwd_params.philox_seed_ptr = seed;
bwd_params.philox_offset1 = offset;
bwd_params.philox_offset2 = 0;
bwd_params.causal_type = causal_type;
bwd_params.varlen_type = varlen_type;
bwd_params.window_left = window_left;
bwd_params.window_right = window_right;
bwd_params.DQ_ACC = dq_acc_lazy;

NVTE_CHECK_CUDA(attn_bwd(bwd_params, bwd_params.kVersion, stream));
}
#endif // USE_FUSED_ATTN_AOTRITON
} // namespace fused_attn_rocm
@@ -495,6 +582,7 @@ void fused_attn_aotriton_fwd_qkvpacked(
void fused_attn_aotriton_bwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
float attn_scale, float dropout,
int32_t window_size_left, int32_t window_size_right,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO,
const Tensor* output_S,
@@ -532,12 +620,14 @@
fused_attn_aotriton_bwd_impl(
b, h, h, max_seqlen, max_seqlen, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout,
bias_type, attn_mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV,
devPtrdO,
input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr,
reinterpret_cast<const uint64_t *>(rng_state->data.dptr),
reinterpret_cast<const uint64_t *>(rng_state->data.dptr) + 1,
nvte_to_aotriton_dtype(QKV_type),
@@ -652,6 +742,7 @@ void fused_attn_aotriton_fwd_kvpacked(
void fused_attn_aotriton_bwd_kvpacked(
size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float dropout,
int32_t window_size_left, int32_t window_size_right,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
const Tensor* output_S,
@@ -692,12 +783,14 @@
fused_attn_aotriton_bwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout,
bias_type, attn_mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV,
devPtrdO,
input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
reinterpret_cast<const uint64_t *>(rng_state->data.dptr),
reinterpret_cast<const uint64_t *>(rng_state->data.dptr) + 1,
nvte_to_aotriton_dtype(QKV_type),
@@ -803,6 +896,7 @@ void fused_attn_aotriton_fwd(
void fused_attn_aotriton_bwd(
size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float dropout,
int32_t window_size_left, int32_t window_size_right,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO,
const Tensor* output_S,
@@ -831,12 +925,14 @@
fused_attn_aotriton_bwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout,
bias_type, attn_mask_type,
devPtrQ, devPtrK, devPtrV,
devPtrO, devPtrSoftmaxStats,
devPtrdQ, devPtrdK, devPtrdV,
devPtrdO,
input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
reinterpret_cast<const uint64_t *>(rng_state->data.dptr),
reinterpret_cast<const uint64_t *>(rng_state->data.dptr) + 1,
nvte_to_aotriton_dtype(QKV_type),
3 changes: 3 additions & 0 deletions transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
@@ -47,6 +47,7 @@ void fused_attn_aotriton_fwd_qkvpacked(
void fused_attn_aotriton_bwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
float attn_scale, float dropout,
int32_t window_size_left, int32_t window_size_right,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO,
const Tensor* output_S,
@@ -72,6 +73,7 @@ void fused_attn_aotriton_fwd_kvpacked(
void fused_attn_aotriton_bwd_kvpacked(
size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float dropout,
int32_t window_size_left, int32_t window_size_right,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
const Tensor* output_S,
@@ -98,6 +100,7 @@ void fused_attn_aotriton_fwd(
void fused_attn_aotriton_bwd(
size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
float attn_scale, float dropout,
int32_t window_size_left, int32_t window_size_right,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO,
const Tensor* output_S,