Skip to content

Commit a271a6d

Browse files
bottlermeta-codesync[bot]
authored and committed
Move hip_fmha op schemas next to their implementations (#245)
Summary: Pull Request resolved: #245 When `MSLK_BUILD_HIP_FMHA=0`, `attention.cpp` was still compiled into the main library and registered schemas for `efficient_attention_forward_ck`, `efficient_attention_backward_ck`, and `_ck_rand_uniform` via `TORCH_LIBRARY_FRAGMENT`, even though the `TORCH_LIBRARY_IMPL` bindings and kernel implementations (in `mslk_hip_fmha`) were absent. This left unimplemented operators registered in the library — the op appears in the dispatcher but calling it fails. Fix by moving each `m.def` into the same file as its `TORCH_LIBRARY_IMPL`, inside `hip_fmha/`. Since those files are only compiled as part of the `mslk_hip_fmha` static library, schema and implementation now come and go together. The decoder ops remain in `attention.cpp` since their situation differs. Reviewed By: cthi Differential Revision: D97933992 fbshipit-source-id: 121f9ce6c707288a6923fd1d62a5611a90659cf6
1 parent 714d498 commit a271a6d

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

csrc/attention/ck/fmha/attention.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,14 @@ PyMODINIT_FUNC PyInit__C(void) {
2222

2323
TORCH_LIBRARY_FRAGMENT(xformers, m) {
2424
#if defined(USE_ROCM)
25-
m.def(TORCH_SELECTIVE_SCHEMA(
26-
"xformers::efficient_attention_forward_ck(Tensor query, "
27-
"Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, "
28-
"Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, "
29-
"bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size, Tensor? block_tables, int? page_size) -> (Tensor, Tensor?, int, int)"));
25+
// Schemas for ops whose implementations live in hip_fmha/ are registered
26+
// there, alongside their TORCH_LIBRARY_IMPL, so that they are absent from
27+
// builds where hip_fmha is not compiled (e.g. MSLK_BUILD_HIP_FMHA=0).
3028
m.def(TORCH_SELECTIVE_SCHEMA(
3129
"xformers::efficient_attention_forward_decoder_ck(Tensor query, "
3230
"Tensor key, Tensor value, Tensor? seq_positions, float scale) -> Tensor"));
3331
m.def(TORCH_SELECTIVE_SCHEMA(
3432
"xformers::efficient_attention_forward_decoder_splitk_ck(Tensor query, Tensor key, "
3533
" Tensor value, Tensor? seq_positions, float scale, int split_k) -> Tensor"));
36-
#ifndef FMHA_OMIT_BACKWARD
37-
m.def(TORCH_SELECTIVE_SCHEMA(
38-
"xformers::efficient_attention_backward_ck(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, int? max_seqlen_k, Tensor? seqlen_k, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, int custom_mask_type, float? scale, int? window_size) -> (Tensor, Tensor, Tensor, Tensor)"));
39-
#endif
40-
m.def(TORCH_SELECTIVE_SCHEMA(
41-
"xformers::_ck_rand_uniform(float p, Tensor out) -> Tensor"));
4234
#endif
4335
}

csrc/attention/ck/fmha/hip_fmha/attention_backward_generic_ck_tiled.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,11 @@ efficient_attention_backward_ck_meta(
631631

632632
} // namespace
633633

634+
TORCH_LIBRARY_FRAGMENT(xformers, m) {
635+
m.def(TORCH_SELECTIVE_SCHEMA(
636+
"xformers::efficient_attention_backward_ck(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, Tensor? seqstart_k, int? max_seqlen_q, int? max_seqlen_k, Tensor? seqlen_k, Tensor logsumexp, Tensor output, float dropout_p, int rng_seed, int rng_offset, int custom_mask_type, float? scale, int? window_size) -> (Tensor, Tensor, Tensor, Tensor)"));
637+
}
638+
634639
TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
635640
m.impl(
636641
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),

csrc/attention/ck/fmha/hip_fmha/attention_ck_rand_uniform.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ at::Tensor rand_uniform_int(
9494

9595
} // namespace
9696

97+
TORCH_LIBRARY_FRAGMENT(xformers, m) {
98+
m.def(TORCH_SELECTIVE_SCHEMA(
99+
"xformers::_ck_rand_uniform(float p, Tensor out) -> Tensor"));
100+
}
101+
97102
TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
98103
m.impl(
99104
TORCH_SELECTIVE_NAME("xformers::_ck_rand_uniform"),

csrc/attention/ck/fmha/hip_fmha/attention_forward_generic_ck_tiled.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,14 @@ efficient_attention_forward_ck_meta(
524524

525525
} // namespace
526526

527+
TORCH_LIBRARY_FRAGMENT(xformers, m) {
528+
m.def(TORCH_SELECTIVE_SCHEMA(
529+
"xformers::efficient_attention_forward_ck(Tensor query, "
530+
"Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, "
531+
"Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, "
532+
"bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size, Tensor? block_tables, int? page_size) -> (Tensor, Tensor?, int, int)"));
533+
}
534+
527535
TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
528536
m.impl(
529537
TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),

0 commit comments

Comments (0)