Skip to content

Commit c402a65

Browse files
committed
[EXP] Mask off out-of-bound dot products
WARNING: does not yet pass the entire unit test sweep. Does not affect existing functionality; the failing cases miss the same error threshold that existing use cases (without clipping) are held to.
1 parent 3b61d7c commit c402a65

File tree

5 files changed

+106
-59
lines changed

5 files changed

+106
-59
lines changed

csrc/include/natten/cuda/fna/epilogue/epilogue_rescale_output.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ class MemoryEfficientAttentionNormalize {
179179
multiplies<ComputeFragment> mul_add_source;
180180
multiply_add<ComputeFragment> mul_add_accumulator;
181181

182-
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
182+
auto s_prime = s_prime_[row];
183+
auto scale = s_prime == 0 ? 0 : 1 / s_prime;
184+
ElementCompute alpha = isLast ? scale : 1;
183185
ElementCompute beta = alpha * m_prime_[row];
184186

185187
intermediate = mul_add_source(beta, converted_source); // X = beta * C
@@ -209,7 +211,9 @@ class MemoryEfficientAttentionNormalize {
209211
ComputeFragment intermediate;
210212
multiplies<ComputeFragment> mul_accumulator;
211213

212-
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
214+
auto s_prime = s_prime_[row];
215+
auto scale = s_prime == 0 ? 0 : 1 / s_prime;
216+
ElementCompute alpha = isLast ? scale : 1;
213217

214218
intermediate = mul_accumulator(
215219
alpha, converted_accumulator); // X = alpha * C + uniform

csrc/include/natten/cuda/fna/kernel_backward.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,30 +1639,34 @@ struct FusedNeighborhoodAttentionBackwardKernel {
16391639
auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
16401640
lane_id, warp_id, output_tile_coords);
16411641

1642-
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
1643-
// SCALING.
1642+
// Dot product scale
1643+
accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
1644+
1645+
// (Optional) clip dot products (mask off out of bound dot products)
16441646
if (p.has_dot_product_clip) {
16451647
if (not p.has_dot_product_max) {
16461648
for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
1647-
accum[i] = cutlass::fast_max(accum[i], p.dot_product_min);
1649+
accum[i] = accum[i] < p.dot_product_min
1650+
? -cutlass::platform::numeric_limits<accum_t>::infinity()
1651+
: accum[i];
16481652
}
16491653
} else if (not p.has_dot_product_min) {
16501654
for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
1651-
accum[i] = cutlass::fast_min(accum[i], p.dot_product_max);
1655+
accum[i] = accum[i] > p.dot_product_max
1656+
? -cutlass::platform::numeric_limits<accum_t>::infinity()
1657+
: accum[i];
16521658
}
16531659
} else {
16541660
// assert(p.has_dot_product_min && p.has_dot_product_max);
16551661
for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
1656-
accum[i] = cutlass::fast_max(
1657-
cutlass::fast_min(accum[i], p.dot_product_max),
1658-
p.dot_product_min);
1662+
accum[i] =
1663+
(accum[i] < p.dot_product_min || accum[i] > p.dot_product_max)
1664+
? -cutlass::platform::numeric_limits<accum_t>::infinity()
1665+
: accum[i];
16591666
}
16601667
}
16611668
}
16621669

1663-
// Dot product scale
1664-
accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
1665-
16661670
if (not p.is_fully_block_sparse) {
16671671
// Neighborhood Attention masking
16681672
Dim first_col, query_bound, row_idx;

csrc/include/natten/cuda/fna/kernel_forward.h

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -745,23 +745,30 @@ struct FusedNeighborhoodAttentionKernel {
745745
MM1::Mma::drain_cp_asyncs();
746746
}
747747

748-
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
749-
// SCALING.
748+
// (Optional) clip dot products (mask off out of bound dot products)
750749
if (p.has_dot_product_clip) {
751750
if (not p.has_dot_product_max) {
752751
for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
753-
accum[i] = cutlass::fast_max(accum[i], p.dot_product_min);
752+
accum[i] = accum[i] * p.scale;
753+
accum[i] = accum[i] < p.dot_product_min
754+
? -cutlass::platform::numeric_limits<accum_t>::infinity()
755+
: accum[i];
754756
}
755757
} else if (not p.has_dot_product_min) {
756758
for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
757-
accum[i] = cutlass::fast_min(accum[i], p.dot_product_max);
759+
accum[i] = accum[i] * p.scale;
760+
accum[i] = accum[i] > p.dot_product_max
761+
? -cutlass::platform::numeric_limits<accum_t>::infinity()
762+
: accum[i];
758763
}
759764
} else {
760765
// assert(p.has_dot_product_min && p.has_dot_product_max);
761766
for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
762-
accum[i] = cutlass::fast_max(
763-
cutlass::fast_min(accum[i], p.dot_product_max),
764-
p.dot_product_min);
767+
accum[i] = accum[i] * p.scale;
768+
accum[i] =
769+
(accum[i] < p.dot_product_min || accum[i] > p.dot_product_max)
770+
? -cutlass::platform::numeric_limits<accum_t>::infinity()
771+
: accum[i];
765772
}
766773
}
767774
}
@@ -823,7 +830,7 @@ struct FusedNeighborhoodAttentionKernel {
823830
last_kv_col,
824831
is_first_kv_iter,
825832
iteratorC_tile_offset,
826-
p.scale);
833+
p.has_dot_product_clip ? 1.0 : p.scale);
827834

828835
// Output results to shared-memory
829836

@@ -999,8 +1006,13 @@ struct FusedNeighborhoodAttentionKernel {
9991006
map_index_to_coord((int32_t)thread_id(), problem_size_0_m);
10001007
auto query_offset = (query_idx * p.lse_strideM).sum();
10011008
if (is_coord_within_upper_bound(query_idx, problem_size_0_m)) {
1002-
p.logsumexp_ptr[query_offset] = accum_t(mi[thread_id()] / kLog2e) +
1003-
cutlass::fast_log(accum_t(s_prime[thread_id()]));
1009+
if (mi[thread_id()] ==
1010+
-cutlass::platform::numeric_limits<accum_t>::infinity()) {
1011+
p.logsumexp_ptr[query_offset] = 0.0f;
1012+
} else {
1013+
p.logsumexp_ptr[query_offset] = accum_t(mi[thread_id()] / kLog2e) +
1014+
cutlass::fast_log(accum_t(s_prime[thread_id()]));
1015+
}
10041016
//} else if (query_offset < lse_dim) {
10051017
// p.logsumexp_ptr[query_offset] =
10061018
// cutlass::platform::numeric_limits<accum_t>::infinity();

csrc/include/natten/cuda/reference/fna_reference_backward.hpp

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -116,23 +116,30 @@ void __global__ fna_bwd_reference_dQ_kernel(
116116
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
117117
} // for idx_D1
118118

119-
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
120-
// SCALING.
119+
acc_qk *= attn_scale;
120+
acc_dov *= attn_scale;
121+
acc_doo *= attn_scale;
122+
123+
// (Optional) clip dot products (mask off out of bound dot products)
121124
if (has_dot_product_min || has_dot_product_max) {
122125
if (not has_dot_product_max) {
123-
acc_qk = cutlass::fast_max(acc_qk, dot_product_min);
126+
acc_qk = acc_qk < dot_product_min
127+
? -cutlass::platform::numeric_limits<
128+
ElementAccumulator>::infinity()
129+
: acc_qk;
124130
} else if (not has_dot_product_min) {
125-
acc_qk = cutlass::fast_min(acc_qk, dot_product_max);
131+
acc_qk = acc_qk > dot_product_max
132+
? -cutlass::platform::numeric_limits<
133+
ElementAccumulator>::infinity()
134+
: acc_qk;
126135
} else {
127-
acc_qk = cutlass::fast_max(
128-
cutlass::fast_min(acc_qk, dot_product_max), dot_product_min);
136+
acc_qk = (acc_qk < dot_product_min || acc_qk > dot_product_max)
137+
? -cutlass::platform::numeric_limits<
138+
ElementAccumulator>::infinity()
139+
: acc_qk;
129140
}
130141
}
131142

132-
acc_qk *= attn_scale;
133-
acc_dov *= attn_scale;
134-
acc_doo *= attn_scale;
135-
136143
auto id = make_identity_tensor(make_shape(1, 1));
137144
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
138145
frag(0) = acc_qk;
@@ -246,23 +253,30 @@ void __global__ fna_bwd_reference_dK_kernel(
246253
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
247254
} // for idx_D1
248255

249-
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
250-
// SCALING.
256+
acc_qk *= attn_scale;
257+
acc_dov *= attn_scale;
258+
acc_doo *= attn_scale;
259+
260+
// (Optional) clip dot products (mask off out of bound dot products)
251261
if (has_dot_product_min || has_dot_product_max) {
252262
if (not has_dot_product_max) {
253-
acc_qk = cutlass::fast_max(acc_qk, dot_product_min);
263+
acc_qk = acc_qk < dot_product_min
264+
? -cutlass::platform::numeric_limits<
265+
ElementAccumulator>::infinity()
266+
: acc_qk;
254267
} else if (not has_dot_product_min) {
255-
acc_qk = cutlass::fast_min(acc_qk, dot_product_max);
268+
acc_qk = acc_qk > dot_product_max
269+
? -cutlass::platform::numeric_limits<
270+
ElementAccumulator>::infinity()
271+
: acc_qk;
256272
} else {
257-
acc_qk = cutlass::fast_max(
258-
cutlass::fast_min(acc_qk, dot_product_max), dot_product_min);
273+
acc_qk = (acc_qk < dot_product_min || acc_qk > dot_product_max)
274+
? -cutlass::platform::numeric_limits<
275+
ElementAccumulator>::infinity()
276+
: acc_qk;
259277
}
260278
}
261279

262-
acc_qk *= attn_scale;
263-
acc_dov *= attn_scale;
264-
acc_doo *= attn_scale;
265-
266280
auto id = make_identity_tensor(make_shape(1, 1));
267281
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
268282
frag(0) = acc_qk;
@@ -374,21 +388,28 @@ void __global__ fna_bwd_reference_dV_kernel(
374388
acc_qk += rQ * rK;
375389
} // for idx_D0
376390

377-
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
378-
// SCALING.
391+
acc_qk *= attn_scale;
392+
393+
// (Optional) clip dot products (mask off out of bound dot products)
379394
if (has_dot_product_min || has_dot_product_max) {
380395
if (not has_dot_product_max) {
381-
acc_qk = cutlass::fast_max(acc_qk, dot_product_min);
396+
acc_qk = acc_qk < dot_product_min
397+
? -cutlass::platform::numeric_limits<
398+
ElementAccumulator>::infinity()
399+
: acc_qk;
382400
} else if (not has_dot_product_min) {
383-
acc_qk = cutlass::fast_min(acc_qk, dot_product_max);
401+
acc_qk = acc_qk > dot_product_max
402+
? -cutlass::platform::numeric_limits<
403+
ElementAccumulator>::infinity()
404+
: acc_qk;
384405
} else {
385-
acc_qk = cutlass::fast_max(
386-
cutlass::fast_min(acc_qk, dot_product_max), dot_product_min);
406+
acc_qk = (acc_qk < dot_product_min || acc_qk > dot_product_max)
407+
? -cutlass::platform::numeric_limits<
408+
ElementAccumulator>::infinity()
409+
: acc_qk;
387410
}
388411
}
389412

390-
acc_qk *= attn_scale;
391-
392413
auto id = make_identity_tensor(make_shape(1, 1));
393414
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
394415
frag(0) = acc_qk;

csrc/include/natten/cuda/reference/fna_reference_forward.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,20 +138,26 @@ void __global__ fna_reference_kernel(
138138
acc += eQ * eK;
139139
}
140140

141-
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
142-
// SCALING.
141+
acc = acc * attn_scale;
142+
143+
// (Optional) clip dot products (mask off out of bound dot products)
143144
if (has_dot_product_min || has_dot_product_max) {
144145
if (not has_dot_product_max) {
145-
acc = cutlass::fast_max(acc, dot_product_min);
146+
acc = acc < dot_product_min ? -cutlass::platform::numeric_limits<
147+
ElementAccumulator>::infinity()
148+
: acc;
146149
} else if (not has_dot_product_min) {
147-
acc = cutlass::fast_min(acc, dot_product_max);
150+
acc = acc > dot_product_max ? -cutlass::platform::numeric_limits<
151+
ElementAccumulator>::infinity()
152+
: acc;
148153
} else {
149-
acc = cutlass::fast_max(
150-
cutlass::fast_min(acc, dot_product_max), dot_product_min);
154+
acc = (acc < dot_product_min || acc > dot_product_max)
155+
? -cutlass::platform::numeric_limits<
156+
ElementAccumulator>::infinity()
157+
: acc;
151158
}
152159
}
153160

154-
acc = acc * attn_scale;
155161
auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{});
156162
frag(0) = acc;
157163
attention_mask.apply_mask(
@@ -212,17 +218,17 @@ void __global__ fna_reference_kernel(
212218
__syncthreads();
213219
}
214220

221+
ElementAccumulator scale = sum == 0.0f ? 0.0f : 1.0f / sum;
215222
for (int i = 0; i < DimPerThread; ++i) {
216223
int idx_D = threadIdx.x + i * blockDim.x;
217224
if (idx_D < size<1>(mO)) {
218-
ElementAccumulator scale = 1.0f / sum;
219225
mO(idx_Q + offset_Q, idx_D, idx_L) =
220226
static_cast<typename TensorO::value_type>(final_acc[i] * scale);
221227
}
222228
}
223229

224230
if (threadIdx.x == 0) {
225-
mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS;
231+
mLSE(idx_Q + offset_Q, idx_L) = sum == 0.0f ? 0.0f : (log(sum) + maxS);
226232
}
227233
}
228234
}

0 commit comments

Comments
 (0)