Skip to content

Commit 64ca7ab

Browse files
committed
[Executorch][BE] Rename sdpa_with_kv_cache.py to custom_ops.py
Because now we have more than sdpa_with_kv_cache in it Differential Revision: [D66269486](https://our.internmc.facebook.com/intern/diff/D66269486/) ghstack-source-id: 254678999 Pull Request resolved: #6996
1 parent 5688320 commit 64ca7ab

File tree

8 files changed

+7
-6
lines changed

8 files changed

+7
-6
lines changed

examples/models/llama/runner/native.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from executorch.examples.models.llama.runner.generation import LlamaRunner
2424

2525
# Note: import this after portable_lib
26-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
26+
import executorch.extension.llm.custom_ops # noqa # usort: skip
2727
from executorch.kernels import quantized # noqa
2828

2929

examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
9999

100100

101101
def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
102-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa
102+
import executorch.extension.llm.custom_ops # noqa
103103

104104
_replace_sdpa_with_custom_op(module)
105105
return module

examples/models/llava/test/test_llava.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from executorch.extension.pybindings.portable_lib import (
1919
_load_for_executorch_from_buffer,
2020
)
21-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
21+
import executorch.extension.llm.custom_ops # noqa # usort: skip
2222
from executorch.kernels import quantized # noqa # usort: skip
2323

2424
logging.basicConfig(level=logging.INFO)

examples/models/llava/test/test_pte.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from PIL import Image
1515

1616
# Custom ops has to be loaded after portable_lib.
17-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
17+
import executorch.extension.llm.custom_ops # noqa # usort: skip
1818
from executorch.kernels import quantized # noqa # usort: skip
1919

2020
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"

extension/llm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ A sampler class in C++ to sample the logistics given some hyperparameters.
3838
## custom_ops
3939
Contains custom op, such as:
4040
- custom sdpa: implements CPU flash attention and avoids copies by taking the kv cache as one of its arguments.
41-
- _sdpa_with_kv_cache.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
41+
- _custom_ops.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
4242
- _op_sdpa.cpp_: the optimized operator implementation and registration of _sdpa_with_kv_cache.out_.
4343

4444
## runner
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .custom_ops import *

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
import torch.nn.functional as F
1313

14-
from .sdpa_with_kv_cache import custom_ops_lib # noqa
14+
from .custom_ops import custom_ops_lib # noqa
1515

1616

1717
def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):

0 commit comments

Comments
 (0)