Skip to content

Commit 03b1fe0

Browse files
committed
memory efficient 1/2
1 parent 4e1e4cb commit 03b1fe0

File tree

9 files changed

+1179
-234
lines changed

9 files changed

+1179
-234
lines changed

.ci/scripts/test_model_e2e.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ case "$HF_MODEL" in
126126
esac
127127

128128
echo "::group::Setup ExecuTorch Requirements"
129-
./install_requirements.sh
129+
# ./install_requirements.sh
130130
pip list
131131
echo "::endgroup::"
132132

backends/aoti/aoti_backend.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ def preprocess(
214214
with open(so_path, "rb") as f:
215215
so_data = f.read()
216216

217+
print("so_path: ", so_path)
218+
217219
# Read weights blob
218220
with open(blob_path, "rb") as f:
219221
blob_data = f.read()
@@ -229,9 +231,9 @@ def preprocess(
229231
method_name + "_weights_blob", blob_data, 1, weights_blob_data_type
230232
)
231233

232-
# Clean up the generated files
233-
os.remove(so_path)
234-
os.remove(blob_path)
234+
# # Clean up the generated files
235+
# os.remove(so_path)
236+
# os.remove(blob_path)
235237

236238
return PreprocessResult(
237239
processed_bytes=b"",

backends/cuda/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ install(
9898
set(_aoti_cuda_shim_sources
9999
runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp
100100
runtime/guard.cpp runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
101+
runtime/shims/sdpa.cu
101102
${EXECUTORCH_ROOT}/backends/aoti/common_shims.cpp
102103
)
103104

@@ -130,12 +131,12 @@ target_link_options(
130131
aoti_cuda_shims PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
131132
)
132133

133-
# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and
134+
# Link against CUDA::cudart, CUDA::cublas, common AOTI library, cuda_tensor_maker, and
134135
# platform utilities
135136
target_link_libraries(
136137
aoti_cuda_shims
137-
PRIVATE cuda_platform
138-
PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
138+
PRIVATE cuda_platform executorch_core
139+
PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart CUDA::cublas ${CMAKE_DL_LIBS}
139140
)
140141

141142
if(NOT MSVC)

backends/cuda/cuda_backend.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,20 @@ def get_aoti_compile_options(
137137

138138
return options
139139

140-
@classmethod
141-
def get_extra_aoti_compile_context_manager(cls):
142-
"""
143-
Return SDPA MATH backend context manager for CUDA compilation.
144-
145-
This context manager plays as a fallback solution for any remaining PyTorch SDPA
146-
operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.
147-
148-
Note:
149-
- If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
150-
this context manager will have no effect on those ops (they are no longer
151-
PyTorch SDPA ops).
152-
- If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
153-
context manager will force them to use the MATH backend, causing them to
154-
be automatically decomposed during compilation.
155-
"""
156-
return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
140+
# @classmethod
141+
# def get_extra_aoti_compile_context_manager(cls):
142+
# """
143+
# Return SDPA MATH backend context manager for CUDA compilation.
144+
145+
# This context manager plays as a fallback solution for any remaining PyTorch SDPA
146+
# operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.
147+
148+
# Note:
149+
# - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
150+
# this context manager will have no effect on those ops (they are no longer
151+
# PyTorch SDPA ops).
152+
# - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
153+
# context manager will force them to use the MATH backend, causing them to
154+
# be automatically decomposed during compilation.
155+
# """
156+
# return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])

0 commit comments

Comments (0)