English | 中文
Run your CUDA code on Moore Threads GPUs — zero code changes required
torchada is an adapter that makes torch_musa (Moore Threads GPU support for PyTorch) compatible with standard PyTorch CUDA APIs. Import it once, and your existing torch.cuda.* code works on MUSA hardware.
Many PyTorch projects are written for NVIDIA GPUs using torch.cuda.* APIs. To run these on Moore Threads GPUs, you would normally need to change every cuda reference to musa. torchada eliminates this by automatically translating CUDA API calls to MUSA equivalents at runtime.
- torch_musa: the PyTorch extension that provides MUSA support; it must be installed (a quick sanity check is sketched after this list)
- Moore Threads GPU: a Moore Threads GPU with the appropriate driver installed
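A minimal sketch of such a check, assuming a working torch_musa install (the calls mirror their torch.cuda counterparts and are illustrative):

```python
# Environment sanity check (illustrative, not part of torchada itself)
import torch
import torch_musa  # noqa: F401  # fails with ImportError if MUSA support is missing

print(torch.musa.is_available())   # expect True with a working driver and GPU
print(torch.musa.device_count())   # number of visible Moore Threads GPUs
```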
Install torchada:

```bash
pip install torchada

# Or install from source
git clone https://github.com/MooreThreads/torchada.git
cd torchada
pip install -e .
```

Then, in your code:

```python
import torchada  # ← Add this one line at the top
import torch

# Your existing CUDA code works unchanged:
x = torch.randn(10, 10).cuda()
print(torch.cuda.device_count())
torch.cuda.synchronize()
```

That's it! All torch.cuda.* APIs are automatically redirected to torch.musa.*.
| Feature | Example |
|---|---|
| Device operations | tensor.cuda(), model.cuda(), torch.device("cuda") |
| Memory management | torch.cuda.memory_allocated(), empty_cache() |
| Synchronization | torch.cuda.synchronize(), Stream, Event |
| Mixed precision | torch.cuda.amp.autocast(), GradScaler() |
| CUDA Graphs | torch.cuda.CUDAGraph, torch.cuda.graph() |
| Distributed | dist.init_process_group(backend='nccl') → uses MCCL |
| torch.compile | torch.compile(model) with all backends |
| C++ Extensions | CUDAExtension, BuildExtension, load() |
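For instance, the memory-management and synchronization entries above also run unchanged. A minimal sketch (variable names are illustrative) using the standard torch.cuda.* calls, which torchada routes to torch.musa.* on MUSA hardware:

```python
import torchada  # must come before the torch.cuda.* calls below
import torch

x = torch.randn(1024, 1024).cuda()

# Memory management through the redirected namespace
print(torch.cuda.memory_allocated())
torch.cuda.empty_cache()

# Streams and events work the same way
stream = torch.cuda.Stream()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.cuda.stream(stream):
    start.record()
    y = x @ x
    end.record()

torch.cuda.synchronize()
print(start.elapsed_time(end))  # elapsed time in milliseconds
```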
Mixed precision training:

```python
import torchada
import torch

model = MyModel().cuda()
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    output = model(data.cuda())
    loss = criterion(output, target.cuda())

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

Distributed training:

```python
import torchada
import torch.distributed as dist

# 'nccl' is automatically mapped to 'mccl' on MUSA
dist.init_process_group(backend='nccl')
```

CUDA Graphs:

```python
import torchada
import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g):  # cuda_graph= keyword works on MUSA
    y = model(x)
```

torch.compile:

```python
import torchada
import torch

compiled_model = torch.compile(model.cuda(), backend='inductor')
```

C++ extensions:

```python
import torchada  # Must import before torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# Standard CUDAExtension works — torchada handles CUDA→MUSA translation
ext = CUDAExtension("my_ext", sources=["kernel.cu"])
```
from torchada import detect_platform, Platform
platform = detect_platform()
if platform == Platform.MUSA:
print("Running on Moore Threads GPU")
elif platform == Platform.CUDA:
print("Running on NVIDIA GPU")
# Or use torch.version-based detection
def is_musa():
import torch
return hasattr(torch.version, 'musa') and torch.version.musa is not NoneDevice type string comparisons fail on MUSA:
device = torch.device("cuda:0") # On MUSA, this becomes musa:0
device.type == "cuda" # Returns False!Solution: Use torchada.is_gpu_device():
import torchada
if torchada.is_gpu_device(device): # Works on both CUDA and MUSA
...
# Or: device.type in ("cuda", "musa")| Function | Description |
|---|---|
| detect_platform() | Returns Platform.CUDA, Platform.MUSA, or Platform.CPU |
| is_musa_platform() | Returns True if running on MUSA |
| is_cuda_platform() | Returns True if running on CUDA |
| is_gpu_device(device) | Returns True if device is CUDA or MUSA |
| CUDA_HOME | Path to the CUDA/MUSA installation |
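A short usage sketch of these helpers (non-authoritative; it assumes the names are importable from the torchada top level, as detect_platform and Platform are above):

```python
from torchada import is_musa_platform, is_cuda_platform, CUDA_HOME

# CUDA_HOME points at whichever toolkit is installed locally (CUDA or MUSA)
if is_musa_platform():
    print("MUSA toolkit:", CUDA_HOME)
elif is_cuda_platform():
    print("CUDA toolkit:", CUDA_HOME)
```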
Note: torch.cuda.is_available() is intentionally NOT redirected — it returns False on MUSA. This allows proper platform detection. Use torch.musa.is_available() or the is_musa() helper shown above instead.
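A minimal sketch of what this looks like in practice (the commented values assume torchada and torch_musa are installed on a machine with a Moore Threads GPU):

```python
import torchada
import torch

print(torch.cuda.is_available())   # False on MUSA: deliberately not redirected
print(torch.musa.is_available())   # True on MUSA: provided by torch_musa

# One way to ask "is any supported GPU present?" across both backends
has_gpu = torch.cuda.is_available() or (
    hasattr(torch, "musa") and torch.musa.is_available()
)
```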
When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:
| CUDA | MUSA |
|---|---|
| cudaMalloc | musaMalloc |
| cudaStream_t | musaStream_t |
| cublasHandle_t | mublasHandle_t |
| at::cuda | at::musa |
| c10::cuda | c10::musa |
| #include <cuda/*> | #include <musa/*> |
See src/torchada/_mapping.py for the complete mapping table (380+ mappings).
```
# pyproject.toml or requirements.txt
torchada>=0.1.16
```

```python
# At your application entry point
def is_musa():
    import torch
    return hasattr(torch.version, "musa") and torch.version.musa is not None

if is_musa():
    import torchada  # noqa: F401

# Rest of your code uses torch.cuda.* as normal
```

```python
# Include MUSA in GPU capability checks
if is_nvidia() or is_musa():
    ENABLE_FLASH_ATTENTION = True
```

```python
# Instead of: device.type == "cuda"
# Use: device.type in ("cuda", "musa")
# Or: torchada.is_gpu_device(device)
```

| Project | Category | Status |
|---|---|---|
| Xinference | Model Serving | ✅ Merged |
| LightLLM | Model Serving | ✅ Merged |
| LightX2V | Image/Video Generation | ✅ Merged |
| SGLang | Model Serving | In Progress |
| ComfyUI | Image/Video Generation | In Progress |
MIT License
