Skip to content

Commit 54f1cc6

Browse files
authored
Add Neural Accelerator Support (#2772)
1 parent b3825ac commit 54f1cc6

22 files changed

+7290
-47
lines changed

mlx/backend/metal/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ if(NOT MLX_METAL_PATH)
121121
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
122122
endif()
123123

124+
# Decide whether the Neural Accelerator (NAX) code paths are built.
#
# NAX is enabled only when MLX_METAL_VERSION is at least 400 and the macOS SDK
# is at least 26.2. When enabled, MLX_ENABLE_NAX is set to TRUE and the `mlx`
# target is compiled with the MLX_ENABLE_NAX preprocessor definition;
# otherwise MLX_ENABLE_NAX is set to FALSE.
#
# The SDK check uses VERSION_GREATER_EQUAL so the value is compared
# component-wise as a version string. A plain numeric GREATER_EQUAL would read
# "26.10" as the real number 26.1 (older than 26.2) and would never be true for
# a three-component value such as "26.2.1".
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION
                                              VERSION_GREATER_EQUAL 26.2))
  set(MLX_ENABLE_NAX TRUE)
  target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX)
else()
  set(MLX_ENABLE_NAX FALSE)
endif()
131+
124132
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
125133

126134
target_compile_definitions(mlx

mlx/backend/metal/device.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,14 @@ Device& device(mlx::core::Device);
265265

266266
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
267267

268+
#ifdef MLX_ENABLE_NAX
269+
270+
inline bool is_nax_available() {
271+
static bool is_nax_available_ =
272+
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
273+
return is_nax_available_;
274+
}
275+
276+
#endif // MLX_ENABLE_NAX
277+
268278
} // namespace mlx::core::metal

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ set(BASE_HEADERS
99
utils.h)
1010

1111
function(build_kernel_base TARGET SRCFILE DEPS)
12-
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
12+
set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
1313
if(MLX_METAL_DEBUG)
1414
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
1515
endif()
16+
if(MLX_ENABLE_NAX)
17+
set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
18+
endif()
1619
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
1720
set(METAL_FLAGS ${METAL_FLAGS}
1821
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
@@ -120,6 +123,30 @@ if(NOT MLX_METAL_JIT)
120123
build_kernel(gemv_masked steel/utils.h)
121124
endif()
122125

126+
# NAX kernel variants. These are registered with build_kernel only when
# MLX_ENABLE_NAX is true.
if(MLX_ENABLE_NAX)

  # Header list handed to build_kernel for the NAX GEMM and quantized kernels.
  set(STEEL_NAX_HEADERS
      steel/defines.h
      steel/utils.h
      steel/gemm/transforms.h
      steel/gemm/nax.h
      steel/gemm/gemm_nax.h
      steel/utils/type_traits.h
      steel/utils/integral_constant.h)

  # GEMM kernels: fused first, then gather (same order as separate calls).
  foreach(nax_gemm_kernel steel_gemm_fused_nax steel_gemm_gather_nax)
    build_kernel(steel/gemm/kernels/${nax_gemm_kernel} ${STEEL_NAX_HEADERS})
  endforeach()

  # Quantized kernels: each one also passes its own same-named header ahead
  # of the shared header list.
  foreach(nax_quant_kernel quantized_nax fp_quantized_nax)
    build_kernel(${nax_quant_kernel} ${nax_quant_kernel}.h
                 ${STEEL_NAX_HEADERS})
  endforeach()

  # Header list handed to build_kernel for the NAX attention kernel.
  set(STEEL_NAX_ATTN_HEADERS
      steel/defines.h
      steel/utils.h
      steel/attn/nax.h
      steel/utils/type_traits.h
      steel/utils/integral_constant.h)

  build_kernel(steel/attn/kernels/steel_attention_nax
               ${STEEL_NAX_ATTN_HEADERS})
endif()
149+
123150
add_custom_command(
124151
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
125152
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o

0 commit comments

Comments
 (0)