Skip to content

Commit f511425

Browse files
sssshhhhhha2d8a4v
authored andcommitted
Add causal flag to fa2 (#1976)
1 parent 6c9aeda commit f511425

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

include/ctranslate2/ops/flash_attention.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace ctranslate2 {
66
namespace ops {
77
class FlashAttention : public Op {
88
public:
9-
FlashAttention(float queries_scale, dim_t sliding_window);
9+
FlashAttention(float queries_scale, dim_t sliding_window, bool is_causal = true);
1010

1111
void operator()(StorageView& queries,
1212
StorageView& keys,
@@ -25,6 +25,7 @@ namespace ctranslate2 {
2525
private:
2626
const float _queries_scale;
2727
const dim_t _sliding_window;
28+
const bool _is_causal;
2829
template <Device D>
2930
void compute(StorageView& queries,
3031
StorageView& keys,

src/ops/flash_attention.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
namespace ctranslate2 {
66
namespace ops {
7-
FlashAttention::FlashAttention(float queries_scale, dim_t sliding_window)
7+
FlashAttention::FlashAttention(float queries_scale, dim_t sliding_window, bool is_causal)
88
: _queries_scale(queries_scale)
9-
,_sliding_window(sliding_window)
9+
, _sliding_window(sliding_window)
10+
, _is_causal(is_causal)
1011
{
1112
}
1213

src/ops/flash_attention_gpu.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ namespace ctranslate2 {
232232
num_heads_k = cached_keys->dim(2);
233233
}
234234

235+
bool is_causal = _is_causal;
235236
// causal=true is the same as causal=false in this case
236-
bool is_causal = true;
237237
if (seqlen_q == 1 && !alibi) { is_causal = false; }
238238
if (is_causal) { window_size_right = 0; }
239239

0 commit comments

Comments
 (0)