Changes from all commits (34 commits)
e90b991
[ROCm] manually pick up fwd native padding support from Meekail's PR
wangye805 Oct 16, 2025
9d02d52
Initial update
Micky774 Oct 16, 2025
81bac35
Updated stride
Micky774 Oct 16, 2025
54ee86a
Corrected typing in allocation portions
Micky774 Oct 16, 2025
47a7cab
Applied Ye's patch
Micky774 Oct 17, 2025
0e0064f
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
945ab5b
[ROCm] jax use runtime segment
wangye805 Oct 21, 2025
579b592
[ROCm] get runtime max_seqlen as well
wangye805 Oct 22, 2025
73247d9
[ROCm] support v2 bwd native padding
wangye805 Oct 22, 2025
7e1c3ef
Updated conversion to include bwd pass
Micky774 Oct 22, 2025
51090d3
Merge branch 'yewang12/te_aiter_native_padding_bwd' into zain/aiter-b…
Micky774 Oct 23, 2025
0e121ba
Added BWD BSHD-->THD conversion and minor logic refactor
Micky774 Oct 23, 2025
734692d
Corrected softmax lse bug
Micky774 Oct 23, 2025
5c24188
Updated logic flow and re-calculation
Micky774 Oct 23, 2025
b59d466
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
97073fe
Merge branch 'zain/aiter-bwd-bshd-thd' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
f27a99f
Added env var guard
Micky774 Oct 28, 2025
d757aef
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
33c5912
Updated ptr variables and streamlined dispatch
Micky774 Oct 29, 2025
af57290
Added env guard
Micky774 Oct 29, 2025
bc8f4a7
Corrected bshd_to_thd conversion arguments
Micky774 Oct 29, 2025
b7f2cf8
Corrected logical flow
Micky774 Oct 30, 2025
3e48a02
Guarded memset and corrected allocation
Micky774 Nov 5, 2025
b1094c6
Remove V3 API check and guard memsets
Micky774 Nov 5, 2025
c3a0fce
PR comments
Micky774 Nov 6, 2025
9ab8df4
Updated documentation
Micky774 Nov 10, 2025
2adfb6e
PR review reconciliation
Micky774 Nov 10, 2025
bb3868d
Added explicit test
Micky774 Nov 12, 2025
52c8167
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Nov 12, 2025
6206d58
Formatting for bwd debug
Micky774 Nov 13, 2025
0582851
Resolved error when using mixed formats e.g. sbhd_2bshd
Micky774 Nov 14, 2025
78716de
Updated guard on flash-attention forced support
Micky774 Nov 14, 2025
85bb6f6
Added check for SBHD_2BSHD
Micky774 Nov 14, 2025
a12105d
Added guard on dk/dv memset
Micky774 Nov 14, 2025
10 changes: 10 additions & 0 deletions README.rst
@@ -264,6 +264,16 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
for both the `FusedAttention` and `DotProductAttention` modules.

Certain settings can be enabled to optimize particular workloads, depending on the nature of the inputs and expected outputs; an example invocation follows the list:

* NVTE_CK_RUNTIME_NUM_SEGMENTS - defaults to 0; if set to 1, the JAX integration calculates the number of
segments at runtime. Enabling this also requires disabling GPU graphs by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
* NVTE_CK_RUNTIME_MAX_SEQLEN - defaults to 0; if set to 1, the max sequence length is calculated at runtime.
This can result in speedups when there are many zero-length sequences. Enabling this while using the JAX
integration also requires disabling GPU graphs by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
* NVTE_CK_ZERO_OUT_PAD - defaults to 1; if set to 0, the output of the FA forward pass is not initialized
to zero, so invalid regions (representing padding) may take nonzero values. Only used if the input has padding.
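For example, a JAX training run that computes segment counts and max sequence lengths at runtime could be launched as
`NVTE_CK_RUNTIME_NUM_SEGMENTS=1 NVTE_CK_RUNTIME_MAX_SEQLEN=1 XLA_FLAGS="--xla_gpu_graph_level=0" python train_script.py`,
where `train_script.py` stands in for any JAX entry point rather than a script shipped with TE.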

FA v3 Kernels in CK Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
22 changes: 22 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -277,6 +277,27 @@ def test():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_gqa_mla_thd():
"""
Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration
post-processing for BWD FA with native padding support.
"""
config = ModelConfig(8, 16, 4, 128, 128, 128, 0.0, "padding", "no_bias", head_dim_v=64)
qkv_layout = "thd_thd_thd"
dtype = torch.float16
_, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=True,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
pytest.skip("This test requires the CK fused attention backend.")

test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False)

@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_dot_product_mem_calc():
"""
@@ -368,6 +389,7 @@ def test_dot_product_attention(
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
and not is_mla
):
flash_attn_supported = True

@@ -59,6 +59,7 @@ hipError_t ck_attn_fwd(
uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_varlen_fwd(
@@ -72,6 +73,7 @@ hipError_t ck_attn_varlen_fwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
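// cu_seqlen_*_padded_ptr hold cumulative offsets that include inter-sequence padding;
// callers pass nullptr when sequences are tightly packed (reading assumed from the
// ck_attn_varlen_bwd implementation later in this diff)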
bool is_training,
float scaling_factor,
float dropout_probability,
@@ -82,6 +84,7 @@
uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_thd_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_bwd(
@@ -137,6 +140,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
132 changes: 99 additions & 33 deletions transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
@@ -15,6 +15,24 @@

namespace ck_fused_attn{

// TODO: unify with the binary search in TE/common/fused_attn(rocm)/util
// (there is no device-side std::upper_bound, hence this hand-rolled version).
// In an increasing array of the given size len, find the index such that:
//   array[index] <= target < array[index+1]
// The caller guarantees target >= 0 and target <= array[len-1].
__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}
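
// Illustration (not part of this change): a device-side sanity kernel exercising the
// invariant above on a hypothetical cu_seqlens array. The #ifdef guard name is made up,
// and device-side assert is assumed to be available, as it is under HIP.
#ifdef CK_FUSED_ATTN_BINARY_SEARCH_EXAMPLE
__global__ void binary_search_example_kernel() {
// cumulative offsets of three sequences with lengths 3, 4, and 3
const int32_t cu_seqlens[] = {0, 3, 7, 10};
assert(binary_search(0, cu_seqlens, 4) == 0);  // token 0 falls in sequence 0: [0, 3)
assert(binary_search(5, cu_seqlens, 4) == 1);  // token 5 falls in sequence 1: [3, 7)
assert(binary_search(9, cu_seqlens, 4) == 2);  // token 9 falls in sequence 2: [7, 10)
}
#endif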

// define dk_dv_reduce function only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce(
@@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce(
// define dk_dv_reduce function in THD layout only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_expanded,
const DataType *dv_expanded,
uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded,
Expand All @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd(
uint64_t hdim_idx = threadIdx.x;

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
// locate the sequence this token index falls into within the padded layout
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
// skip tokens in the padding tail of this sequence
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
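// Example with hypothetical values: cu_seqlen_kv = {0, 2, 5}, cu_seqlen_kv_padded = {0, 4, 8};
// tokens 2-3 (padding tail of sequence 0) and token 7 (padding tail of sequence 1) return early.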
// h guaranteed to be a multiple of hg
uint64_t head_idx_offset = h / hg;
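// e.g. h = 16 query heads and hg = 4 kv heads (hypothetical values): each kv head
// accumulates the gradients of h / hg = 4 query heads.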

@@ -164,8 +189,9 @@
// When d_qk != d_v, we need to reduce dk and dv separately
template<typename DataType>
__global__ void dk_or_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_or_dv_expanded,
uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded,
DataType *dk_or_dv,
Expand All @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd(

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be a multiple of hg
uint64_t head_idx_offset = h / hg;

@@ -323,7 +355,7 @@ void log_bwd_config(const char* func_name,
std::cout<<std::endl<<func_name<<std::endl;

// fmha_traits debug
std::cout<<"fmha_traits: "<<std::endl;
std::cout<<std::endl<<"fmha_traits: "<<std::endl;
std::cout<<"hdim_q: "<<fmha_args.hdim_q<<std::endl;
std::cout<<"hdim_v: "<<fmha_args.hdim_v<<std::endl;
std::cout<<"data_type: "<<data_type_str<<std::endl;
@@ -339,7 +371,7 @@
std::cout<<"how_v3_bf16_cvt: "<<how_v3_bf16_cvt<<std::endl;

// fmha_args debug
std::cout<<"fmha_args: "<<std::endl;
std::cout<<std::endl<<"fmha_args: "<<std::endl;
std::cout<<"q_ptr: "<<fmha_args.q_ptr<<std::endl;
std::cout<<"k_ptr: "<<fmha_args.k_ptr<<std::endl;
std::cout<<"v_ptr: "<<fmha_args.v_ptr<<std::endl;
@@ -353,9 +385,15 @@
std::cout<<"dk_ptr: "<<fmha_args.dk_ptr<<std::endl;
std::cout<<"dv_ptr: "<<fmha_args.dv_ptr<<std::endl;
std::cout<<"dbias_ptr: "<<fmha_args.dbias_ptr<<std::endl;
std::cout<<"dq_acc_ptr: "<<fmha_args.dq_acc_ptr<<std::endl;

std::cout<<"seqstart_q_ptr: "<<fmha_args.seqstart_q_ptr<<std::endl;
std::cout<<"seqstart_k_ptr: "<<fmha_args.seqstart_k_ptr<<std::endl;
std::cout<<"seqlen_q_ptr: "<<fmha_args.seqlen_q_ptr<<std::endl;
std::cout<<"seqlen_k_ptr: "<<fmha_args.seqlen_k_ptr<<std::endl;
std::cout<<"cu_seqlen_q_ptr: "<<fmha_args.cu_seqlen_q_ptr<<std::endl;
std::cout<<"cu_seqlen_k_ptr: "<<fmha_args.cu_seqlen_k_ptr<<std::endl;

std::cout<<"seqlen_q: "<<fmha_args.seqlen_q<<std::endl;
std::cout<<"seqlen_k: "<<fmha_args.seqlen_k<<std::endl;
std::cout<<"batch: "<<fmha_args.batch<<std::endl;
@@ -572,9 +610,12 @@ hipError_t ck_attn_bwd(
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr,
dq_acc_ptr, //dq_acc_buf
nullptr,//cu_seqlen_q
nullptr,//cu_seqlen_kv
nullptr,//seqstart_q_ptr
nullptr,//seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
nullptr, //cu_seqlen_q_ptr
nullptr, //cu_seqlen_k_ptr
shape_seqlen_q,
shape_seqlen_k,
batch,
@@ -784,6 +825,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
@@ -915,11 +957,14 @@ hipError_t ck_attn_varlen_bwd(
dq_ptr,
is_mqa_gqa? dk_expanded_ptr:dk_ptr,
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
nullptr,
nullptr, //dbias_ptr
dq_acc_ptr, //dq_acc_buf
cu_seqlen_q_ptr,//cu_seqlen_q
cu_seqlen_kv_ptr,//cu_seqlen_kv
cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr
cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr
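// With native padding, seqstart_* receive the padded cumulative offsets (the memory
// layout of the THD tensors), while cu_seqlen_* receive the unpadded cumulative
// lengths (the valid-token counts).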
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
cu_seqlen_q_ptr, //cu_seqlen_q_ptr
cu_seqlen_kv_ptr, //cu_seqlen_k_ptr
max_seqlen_q, //seqlen_q, unused in group mode
max_seqlen_k, //seqlen_kv, unused in group mode
batch,
Expand Down Expand Up @@ -977,6 +1022,18 @@ hipError_t ck_attn_varlen_bwd(
std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}};
}();

// recompute max_seqlen_q/k at runtime for better performance when sequences are short or zero-length
// lse_workspace_ptr is reused as a scratch buffer
if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
if(std::string(env_p) == "1"){
if(ck_fused_attn_log_config){
std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization." << std::endl;
}
fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream);
fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream);
}
}
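// e.g. when many sequences are zero-length, the runtime maximum can be much smaller than
// the compile-time max_seqlen, shrinking the subsequent kernel launches (see
// NVTE_CK_RUNTIME_MAX_SEQLEN in README.rst).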

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
if (uses_bwd_v3)
@@ -985,17 +1042,17 @@
}

float average_runtime = aiter::mha_bwd(fmha_args,
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
if(average_runtime < 0){
//TODO: better error out system
throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
@@ -1006,6 +1063,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_dv_reduce_thd: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dkv_expanded: "<<stride_h_dk_expanded<<std::endl;
@@ -1018,8 +1077,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_dv_reduce_thd<CK_TILE_TYPE>, grid, block, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
@@ -1030,6 +1090,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dk(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dk: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"stride_h_dk_expanded: "<<stride_h_dk_expanded<<std::endl;
std::cout<<"stride_s_dk_expanded: "<<stride_s_dk_expanded<<std::endl;
@@ -1040,8 +1102,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
Inline review:
Collaborator: Do we have gqa/mqa + MLA testcases w/ and w/o padding? If not, can we create those to verify this flow is actually working?
Contributor (author): Will work on trying to add one on the JAX side -- for now I've added one on the TE side that isn't able to run due to too few backends supporting it, but that may change e.g. as we update AOTriton.
Collaborator: Thanks. Then let's skip the pytorch side gqa/mqa + MLA test for now. You can put a to-do here and add it later when other backends support it.
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
static_cast<CK_TILE_TYPE*>(dk_ptr),
@@ -1050,6 +1113,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dv(d_v);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dv: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dv_expanded: "<<stride_h_dv_expanded<<std::endl;
std::cout<<"stride_s_dv_expanded: "<<stride_s_dv_expanded<<std::endl;
@@ -1060,8 +1125,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dv, 0, stream,
h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dv_expanded, stride_s_dv_expanded,
static_cast<CK_TILE_TYPE*>(dv_ptr),