torchada

English | 中文

Run your CUDA code on Moore Threads GPUs — zero code changes required

torchada is an adapter that makes torch_musa (Moore Threads GPU support for PyTorch) compatible with standard PyTorch CUDA APIs. Import it once, and your existing torch.cuda.* code works on MUSA hardware.

Why torchada?

Many PyTorch projects are written for NVIDIA GPUs using torch.cuda.* APIs. To run these on Moore Threads GPUs, you would normally need to change every cuda reference to musa. torchada eliminates this by automatically translating CUDA API calls to MUSA equivalents at runtime.
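
For a sense of what that saves, here is the kind of manual port torchada makes unnecessary (an illustrative sketch; with torchada imported, the left-hand versions run unchanged on MUSA):

# Manual port you would otherwise have to do by hand:
#   x = torch.randn(10).cuda()     ->  x = torch.randn(10).to("musa")
#   torch.device("cuda:0")         ->  torch.device("musa:0")
#   torch.cuda.synchronize()       ->  torch.musa.synchronize()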

Prerequisites

  • torch_musa: You must have torch_musa installed (this provides MUSA support for PyTorch)
  • Moore Threads GPU: A Moore Threads GPU with proper driver installed

Installation

pip install torchada

# Or install from source
git clone https://github.com/MooreThreads/torchada.git
cd torchada
pip install -e .

Quick Start

import torchada  # ← Add this one line at the top
import torch

# Your existing CUDA code works unchanged:
x = torch.randn(10, 10).cuda()
print(torch.cuda.device_count())
torch.cuda.synchronize()

That's it! All torch.cuda.* APIs are automatically redirected to torch.musa.*.
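
You can verify the redirection by inspecting a tensor's device; as described under Known Limitation below, the device type reads musa on Moore Threads hardware:

import torchada
import torch

x = torch.randn(4).cuda()
print(x.device)  # "musa:0" on Moore Threads GPUs, "cuda:0" on NVIDIA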

What Works

Feature | Example
Device operations | tensor.cuda(), model.cuda(), torch.device("cuda")
Memory management | torch.cuda.memory_allocated(), empty_cache()
Synchronization | torch.cuda.synchronize(), Stream, Event
Mixed precision | torch.cuda.amp.autocast(), GradScaler()
CUDA Graphs | torch.cuda.CUDAGraph, torch.cuda.graph()
Distributed | dist.init_process_group(backend='nccl') → uses MCCL
torch.compile | torch.compile(model) with all backends
C++ Extensions | CUDAExtension, BuildExtension, load()

Examples

Mixed Precision Training

import torchada
import torch

# MyModel, criterion, optimizer, data, and target come from your training loop
model = MyModel().cuda()
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    output = model(data.cuda())
    loss = criterion(output, target.cuda())

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Distributed Training

import torchada
import torch.distributed as dist

# 'nccl' is automatically mapped to 'mccl' on MUSA
dist.init_process_group(backend='nccl')
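
For multi-process training the usual torchrun recipe carries over unchanged. A minimal sketch, assuming torchrun provides LOCAL_RANK as it normally does:

import os
import torchada
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')  # torchada maps 'nccl' to 'mccl' on MUSA
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)        # redirected to the MUSA device

t = torch.ones(1).cuda()
dist.all_reduce(t)                       # sums across ranks over MCCL
print(t.item())                          # equals the world size on every rank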

CUDA Graphs

import torchada
import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g):  # cuda_graph= keyword works on MUSA
    y = model(x)
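
Graph capture normally needs a warmup pass and static tensors that are reused across replays. A fuller sketch following the standard PyTorch capture/replay pattern (the model and shapes are placeholders):

import torchada
import torch

model = torch.nn.Linear(16, 16).cuda().eval()
static_x = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so one-time initialization isn't captured
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into the graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g), torch.no_grad():
    static_y = model(static_x)

# Replay: refill the static input in place, then rerun the captured kernels
static_x.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
print(static_y.sum().item())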

torch.compile

import torchada
import torch

compiled_model = torch.compile(model.cuda(), backend='inductor')
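
The first call compiles; later calls reuse the compiled artifact. A runnable sketch with a placeholder model:

import torchada
import torch

model = torch.nn.Linear(32, 32).cuda()
compiled_model = torch.compile(model, backend='inductor')
y = compiled_model(torch.randn(8, 32, device="cuda"))  # compiles on first call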

Building C++ Extensions

import torchada  # Must import before torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# Standard CUDAExtension works — torchada handles CUDA→MUSA translation
ext = CUDAExtension("my_ext", sources=["kernel.cu"])
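
Wired into a build, that typically looks like the following setup.py (module and file names here are placeholders):

# setup.py
import torchada  # must come before torch.utils.cpp_extension
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name="my_ext",
    ext_modules=[CUDAExtension("my_ext", sources=["kernel.cu"])],
    cmdclass={"build_ext": BuildExtension},
)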

Platform Detection

import torchada
from torchada import detect_platform, Platform

platform = detect_platform()
if platform == Platform.MUSA:
    print("Running on Moore Threads GPU")
elif platform == Platform.CUDA:
    print("Running on NVIDIA GPU")

# Or use torch.version-based detection
def is_musa():
    import torch
    return hasattr(torch.version, 'musa') and torch.version.musa is not None

Known Limitation

Device type string comparisons fail on MUSA:

device = torch.device("cuda:0")  # On MUSA, this becomes musa:0
device.type == "cuda"  # Returns False!

Solution: Use torchada.is_gpu_device():

import torchada

if torchada.is_gpu_device(device):  # Works on both CUDA and MUSA
    ...
# Or: device.type in ("cuda", "musa")

API Reference

Function | Description
detect_platform() | Returns Platform.CUDA, Platform.MUSA, or Platform.CPU
is_musa_platform() | Returns True if running on MUSA
is_cuda_platform() | Returns True if running on CUDA
is_gpu_device(device) | Returns True if device is CUDA or MUSA
CUDA_HOME | Path to the CUDA/MUSA installation

Note: torch.cuda.is_available() is intentionally NOT redirected; it returns False on MUSA so that callers can still tell the two platforms apart. Use torch.musa.is_available() or the is_musa() helper shown above instead.
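
If you need a single check that covers both platforms, one option is a small helper like this (a sketch, not a torchada API):

import torchada
import torch

def gpu_available() -> bool:
    # torch.cuda.is_available() is deliberately untouched, so it stays
    # False on MUSA; check the MUSA backend explicitly as well.
    if torch.cuda.is_available():
        return True
    return hasattr(torch, "musa") and torch.musa.is_available()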

C++ Extension Symbol Mapping

When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:

CUDA | MUSA
cudaMalloc | musaMalloc
cudaStream_t | musaStream_t
cublasHandle_t | mublasHandle_t
at::cuda | at::musa
c10::cuda | c10::musa
#include <cuda/*> | #include <musa/*>

See src/torchada/_mapping.py for the complete mapping table (380+ mappings).

Integrating torchada into Your Project

Step 1: Add Dependency

# pyproject.toml or requirements.txt
torchada>=0.1.16

Step 2: Conditional Import

# At your application entry point
def is_musa():
    import torch
    return hasattr(torch.version, "musa") and torch.version.musa is not None

if is_musa():
    import torchada  # noqa: F401

# Rest of your code uses torch.cuda.* as normal

Step 3: Extend Feature Flags (if applicable)

# Include MUSA in GPU capability checks
# (is_nvidia() stands in for your project's existing NVIDIA check)
if is_nvidia() or is_musa():
    ENABLE_FLASH_ATTENTION = True

Step 4: Fix Device Type Checks (if applicable)

# Instead of: device.type == "cuda"
# Use: device.type in ("cuda", "musa")
# Or: torchada.is_gpu_device(device)
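
For example, a device-agnostic helper might look like this (move_to_gpu is an illustrative name, not a torchada API):

import torchada
import torch

def move_to_gpu(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    # device.type is "cuda" on NVIDIA and "musa" on Moore Threads
    if torchada.is_gpu_device(device):
        return batch.to(device, non_blocking=True)
    return batch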

Projects Using torchada

Project | Category | Status
Xinference | Model Serving | ✅ Merged
LightLLM | Model Serving | ✅ Merged
LightX2V | Image/Video Generation | ✅ Merged
SGLang | Model Serving | In Progress
ComfyUI | Image/Video Generation | In Progress

License

MIT License
