Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
# Backend and feature toggles. Each option() declares a boolean cache entry
# with the default shown as its last argument.
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
# Opt-in ROCm backend: when ON, this file enables the HIP language and
# mlx/CMakeLists.txt adds the backend/rocm subdirectory.
option(MLX_BUILD_ROCM "Build ROCm backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
Expand Down Expand Up @@ -88,6 +89,10 @@ if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()

# The ROCm backend adds .hip sources to the mlx target, so the HIP language is
# enabled here, at the top level, whenever that backend is requested.
if(MLX_BUILD_ROCM)
enable_language(HIP)
endif()

if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
Expand Down
11 changes: 10 additions & 1 deletion mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ else()
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()

if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
# ROCm backend selection: build the real backend from backend/rocm when
# MLX_BUILD_ROCM is ON, otherwise compile no_rocm.cpp into mlx in its place.
# NOTE(review): no_rocm.cpp is not visible here; presumably it provides stub
# definitions for the ROCm entry points - confirm.
if(MLX_BUILD_ROCM)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
endif()

if(MLX_BUILD_METAL
OR MLX_BUILD_CUDA
OR MLX_BUILD_ROCM)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
Expand Down
85 changes: 85 additions & 0 deletions mlx/backend/rocm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# File naming conventions for the ROCm backend:
#
# * Code containing device code lives in .hip/.hpp files; host-only code lives
#   in .cpp/.h files.
# * Device-only code belongs in the device/ subdirectory.
# * Files under device/ must not include anything from outside device/.
#
# The list below keeps its established order; append new files in place rather
# than regrouping.
target_sources(
  mlx
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/binary.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/copy.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/event.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/primitives.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/random.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/rope.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/sort.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/unary.hip
    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

# Defines MLX_USE_ROCM for every source compiled into mlx.
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)

# Embed the device headers in the binary so they are available for JIT
# compilation at runtime.
#
# CONFIGURE_DEPENDS makes the build system re-check the glob, so headers added
# to or removed from device/ are picked up without a manual reconfigure.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.hpp")

# bin2h.cmake receives the header paths relative to MLX_SOURCE_ROOT, joined
# with ':' so the whole list travels as a single -D argument.
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})

# The same headers as absolute paths, used as explicit file-level dependencies
# of the generation step.
set(MLX_JIT_SOURCES_ABS ${MLX_JIT_SOURCES})
list(TRANSFORM MLX_JIT_SOURCES_ABS PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")

# NOTE(review): bin2h.cmake is not visible here; it presumably writes
# gen/rocm_jit_sources.h relative to the working directory - confirm. The
# working directory is spelled out below and matches the default one.
add_custom_command(
  OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h"
  COMMAND
    ${CMAKE_COMMAND} "-DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}"
    "-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG}" -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake" ${MLX_JIT_SOURCES_ABS}
  WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
  COMMENT "Generating rocm_jit_sources.h"
  VERBATIM)

# Named target wrapping the generated header; mlx depends on it so the header
# exists before any mlx source is compiled.
add_custom_target(
  rocm_jit_sources
  DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h")
add_dependencies(mlx rocm_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")

# ROCm dependencies. Both packages are mandatory for this backend, so
# configuration stops if either one cannot be found.
find_package(hip REQUIRED)
find_package(rocblas REQUIRED)

# Link through the imported targets the packages export.
target_link_libraries(
  mlx
  PRIVATE hip::device
          roc::rocblas)

# GPU architectures to generate device code for. The default covers
# gfx900, gfx906, gfx908, gfx90a, gfx1030 and gfx1100; override with
# -DMLX_ROCM_ARCHITECTURES="<semicolon-separated list>".
set(MLX_ROCM_ARCHITECTURES
    "gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1100"
    CACHE STRING "ROCm GPU architectures")
message(STATUS "ROCm GPU architectures: ${MLX_ROCM_ARCHITECTURES}")

# Apply the selection to the HIP sources of the mlx target.
set_target_properties(
  mlx PROPERTIES HIP_ARCHITECTURES "${MLX_ROCM_ARCHITECTURES}")

# The HIP language is not enabled here: the top-level CMakeLists.txt already
# calls enable_language(HIP) when MLX_BUILD_ROCM is ON, and CMake expects
# languages to be enabled in the highest directory common to the targets that
# use them.

# Flags applied to HIP translation units only.
#
# * -fgpu-rdc requests relocatable device code. NOTE(review): this normally
#   needs a matching flag on the link that produces the final binary - confirm
#   the mlx link step handles it.
# * -Wall/-Wextra are passed directly. On the AMD platform CMake's HIP language
#   drives a clang-based compiler, which has no nvcc-style -Xcompiler=
#   forwarding option.
target_compile_options(
  mlx
  PRIVATE "$<$<COMPILE_LANGUAGE:HIP>:-fgpu-rdc>"
          "$<$<COMPILE_LANGUAGE:HIP>:-Wall>"
          "$<$<COMPILE_LANGUAGE:HIP>:-Wextra>")

# No explicit ROCm include directories are added: the hip::device and
# roc::rocblas imported targets that mlx links against already propagate their
# include paths as usage requirements. Routing them only through the imported
# targets also keeps the ROCm headers on the system include path, so the
# warning flags above do not fire inside them.
206 changes: 206 additions & 0 deletions mlx/backend/rocm/allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/rocm/allocator.h"
#include "mlx/backend/rocm/utils.h"
#include "mlx/backend/rocm/worker.h"

#include <fmt/format.h>
#include <hip/hip_runtime.h>
#include <unistd.h>

#include <cassert>

namespace mlx::core {

namespace rocm {

// Sets up the buffer cache and the initial memory limits.
//
// The cache is constructed with the system page size, a callback reporting a
// buffer's size, and a callback that releases a buffer (device memory via
// rocm_free, then the RocmBuffer record itself). Both the memory limit and the
// cache limit start at 80% of the total memory reported by hipMemGetInfo.
RocmAllocator::RocmAllocator()
    : buffer_cache_(
          getpagesize(),
          [](RocmBuffer* cached) { return cached->size; },
          [this](RocmBuffer* cached) {
            rocm_free(cached->data);
            delete cached;
          }) {
  // TODO: Set memory limit for multi-device.
  size_t free_bytes;
  size_t total_bytes;
  CHECK_HIP_ERROR(hipMemGetInfo(&free_bytes, &total_bytes));

  memory_limit_ = total_bytes * 0.8;
  max_pool_size_ = memory_limit_;
}

// Allocates |size| bytes, reusing a cached buffer when one is available and
// falling back to hipMallocManaged otherwise.
//
// Returns a null Buffer when HIP reports hipErrorMemoryAllocation, and throws
// std::runtime_error for any other HIP failure. In both failure cases nothing
// is added to the active-memory accounting and the RocmBuffer record is
// released.
// NOTE(review): a null Buffer is the state free(), size() and raw_ptr()
// already treat as "no allocation"; confirm callers of malloc() check for it.
Buffer RocmAllocator::malloc(size_t size) {
  // Find available buffer from cache.
  std::unique_lock lock(mutex_);
  RocmBuffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    // If we have a lot of memory pressure or are over the maximum cache size,
    // try to reclaim memory from the cache.
    size_t mem_required = get_active_memory() + get_cache_memory() + size;
    if (mem_required >= memory_limit_) {
      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
    }

    // The allocation itself runs without the allocator lock held.
    lock.unlock();
    buf = new RocmBuffer{nullptr, size};
    hipError_t err = hipMallocManaged(&buf->data, size);
    if (err != hipSuccess) {
      // Do not leak the bookkeeping record, and never hand out a Buffer whose
      // data pointer is null.
      delete buf;
      if (err == hipErrorMemoryAllocation) {
        return Buffer{nullptr};
      }
      throw std::runtime_error(
          fmt::format("hipMallocManaged failed: {}.", hipGetErrorString(err)));
    }
    lock.lock();
  }
  active_memory_ += size;
  peak_memory_ = std::max(active_memory_, peak_memory_);

  // Maintain the cache below the requested limit.
  if (get_cache_memory() > max_pool_size_) {
    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
  }

  return Buffer{buf};
}

// Releases |buffer|. A null buffer is ignored. Otherwise the buffer's size is
// subtracted from the active-memory count and the buffer is either recycled
// into the cache (while the cache is below its limit) or freed immediately.
void RocmAllocator::free(Buffer buffer) {
  auto* released = static_cast<RocmBuffer*>(buffer.ptr());
  if (released == nullptr) {
    return;
  }

  std::unique_lock guard(mutex_);
  active_memory_ -= released->size;

  const bool cache_has_room = get_cache_memory() < max_pool_size_;
  if (cache_has_room) {
    buffer_cache_.recycle_to_cache(released);
    return;
  }

  // The cache is at its limit: free outside the allocator lock.
  guard.unlock();
  rocm_free(released->data);
  delete released;
}

// Returns the size recorded for |buffer|, or 0 for a null buffer.
size_t RocmAllocator::size(Buffer buffer) const {
  auto* record = static_cast<RocmBuffer*>(buffer.ptr());
  return record ? record->size : 0;
}

// Adds the calling thread to the set of threads on which rocm_free() calls
// hipFree directly.
void RocmAllocator::register_this_thread() {
  const auto self = std::this_thread::get_id();
  std::lock_guard guard(worker_mutex_);
  allowed_threads_.insert(self);
}

// Frees |buf| with hipFree when called from a registered thread. When called
// from an unregistered thread, the call is rescheduled onto the worker, which
// is created on first use.
void RocmAllocator::rocm_free(void* buf) {
  const auto caller = std::this_thread::get_id();
  {
    std::lock_guard guard(worker_mutex_);
    const bool registered = allowed_threads_.count(caller) != 0;
    if (!registered) {
      if (worker_ == nullptr) {
        worker_.reset(new Worker);
      }
      worker_->add_task([this, buf]() { this->rocm_free(buf); });
      worker_->end_batch();
      worker_->commit();
      return;
    }
  }

  hipFree(buf);
}

// Returns the running total of bytes handed out by malloc() and not yet
// passed to free(). The value is read without taking mutex_.
size_t RocmAllocator::get_active_memory() const {
return active_memory_;
}

// Returns the highest value the active-memory total has reached since
// construction or the last reset_peak_memory(). Read without taking mutex_.
size_t RocmAllocator::get_peak_memory() const {
return peak_memory_;
}

// Resets the recorded peak to zero, under mutex_.
void RocmAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}

// Returns the current memory limit. Read without taking mutex_.
size_t RocmAllocator::get_memory_limit() {
return memory_limit_;
}

// Installs |limit| as the memory limit and returns the previous limit.
size_t RocmAllocator::set_memory_limit(size_t limit) {
  std::lock_guard guard(mutex_);
  const size_t previous = memory_limit_;
  memory_limit_ = limit;
  return previous;
}

// Returns the size reported by the buffer cache.
size_t RocmAllocator::get_cache_memory() const {
  return buffer_cache_.cache_size();
}

// Installs |limit| as the cache limit and returns the previous limit.
size_t RocmAllocator::set_cache_limit(size_t limit) {
  std::lock_guard guard(mutex_);
  const size_t previous = max_pool_size_;
  max_pool_size_ = limit;
  return previous;
}

// Clears the buffer cache.
void RocmAllocator::clear_cache() {
  std::lock_guard guard(mutex_);
  buffer_cache_.clear();
}

// Returns the process-wide ROCm allocator, created on first use.
//
// The instance lives on the heap and is never deleted, so the RocmAllocator
// destructor does not run at exit and cached buffers are intentionally
// leaked. This can save some time at program exit.
RocmAllocator& allocator() {
  static RocmAllocator* const instance = new RocmAllocator;
  return *instance;
}

} // namespace rocm

namespace allocator {

Allocator& allocator() {
return rocm::allocator();
}

void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<rocm::RocmBuffer*>(ptr_)->data;
}

} // namespace allocator

// Free-function memory API: each call forwards to the ROCm allocator singleton.

size_t get_active_memory() {
  return rocm::allocator().get_active_memory();
}

size_t get_peak_memory() {
  return rocm::allocator().get_peak_memory();
}

void reset_peak_memory() {
  rocm::allocator().reset_peak_memory();
}

size_t set_memory_limit(size_t limit) {
  return rocm::allocator().set_memory_limit(limit);
}

size_t get_memory_limit() {
  return rocm::allocator().get_memory_limit();
}

size_t get_cache_memory() {
  return rocm::allocator().get_cache_memory();
}

size_t set_cache_limit(size_t limit) {
  return rocm::allocator().set_cache_limit(limit);
}

void clear_cache() {
  rocm::allocator().clear_cache();
}

// Not supported in ROCm: the argument is ignored and 0 is always returned.
size_t set_wired_limit(size_t) {
  return 0;
}

} // namespace mlx::core
67 changes: 67 additions & 0 deletions mlx/backend/rocm/allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"

#include <mutex>
#include <set>
#include <thread>
#include <utility>

namespace mlx::core::rocm {

class Worker;

using allocator::Buffer;

// Stores ROCm-managed unified memory.
// One record per allocation made by RocmAllocator.
struct RocmBuffer {
// Pointer filled in by hipMallocManaged.
void* data;
// Size in bytes that was requested for this allocation.
size_t size;
};

// Allocator implementation for the ROCm backend. It allocates with
// hipMallocManaged, keeps released buffers in a BufferCache for reuse, and
// tracks active and peak memory.
class RocmAllocator : public allocator::Allocator {
public:
// Allocator interface.
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;

// Register current thread as safe to free buffers.
// In ROCm freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in deadlock.
void register_this_thread();

// Call hipFree in the safe thread.
void rocm_free(void* buf);

// Memory accounting and limits. The setters return the previous value.
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
size_t get_cache_memory() const;
size_t set_cache_limit(size_t limit);
void clear_cache();

private:
// Construction is restricted to the allocator() accessor below.
RocmAllocator();
friend RocmAllocator& allocator();

// worker_mutex_ guards worker_ and allowed_threads_.
std::mutex worker_mutex_;
// Worker that runs frees requested from unregistered threads; created on
// first use.
std::unique_ptr<Worker> worker_;
// Threads registered through register_this_thread().
std::set<std::thread::id> allowed_threads_;

// mutex_ is held while malloc/free update the cache and the counters and
// while limits are changed; the plain getters read without it.
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<RocmBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
};

RocmAllocator& allocator();

} // namespace mlx::core::rocm
Loading