1 change: 0 additions & 1 deletion .gitignore
@@ -22,7 +22,6 @@ python/triton/backends/
third_party/iluvatar/iluvatarTritonPlugin.so
third_party/triton_shared/
third_party/xpu/backend/xpu3
third_party/ascend

# Proton
python/triton/profiler
4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -33,6 +33,8 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
elseif(FLAGTREE_BACKEND STREQUAL "ascend")
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)
add_compile_options("-Wno-deprecated-declarations")
add_compile_options("-Wno-error=deprecated-declarations")
endif()
set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}")
if(FLAGTREE_PLUGIN)
@@ -476,7 +478,7 @@ endif()

add_subdirectory(third_party/f2reduce)

if(NOT FLAGTREE_BACKEND)
if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND MATCHES "^(aipu|ascend|tsingmicro)$")
add_subdirectory(bin)
add_subdirectory(test)
endif()
32 changes: 6 additions & 26 deletions include/flagtree/Common/UnifiedHardware.h
@@ -15,36 +15,16 @@ namespace flagtree {
class UnifiedHardware {

public:
~UnifiedHardware() = default;
UnifiedHardware() = default;
#ifdef FLAGTREE_BACKEND
static bool registered;
int getDMATag();
int getSharedMemoryTag();
std::string getFlagTreeBackend() { return FLAGTREE_BACKEND; }
#else
static constexpr bool registered = false;
void *getDMATag() { return nullptr; }
void *getSharedMemoryTag() { return nullptr; }
std::string getFlagTreeBackend() { return "default"; }
#endif
virtual ~UnifiedHardware() = default;
virtual bool isRegistered() const;
virtual int getDMATag() const;
virtual int getSharedMemoryTag() const;
virtual std::string getReduceStrategy() const;
virtual std::string getFlagTreeBackend() const;
};

std::unique_ptr<UnifiedHardware> createUnifiedHardwareManager();

} // namespace flagtree
} // namespace mlir

#define SET_REGISTER_FLAG(_Ty, FLAG) bool _Ty::registered = FLAG;

#define FLAGTREE_REGISTRAR_GET(_Ty, _Fn, _VAL) \
decltype(_VAL) _Ty::get##_Fn() { return static_cast<decltype(_VAL)>(_VAL); }

#ifdef FLAGTREE_BACKEND
#define FLAGTREE_REGISTRAR(fn_name, _VAL) \
using UnifiedHardwareType = mlir::flagtree::UnifiedHardware; \
FLAGTREE_REGISTRAR_GET(UnifiedHardwareType, fn_name, _VAL) \
SET_REGISTER_FLAG(UnifiedHardwareType, true)
#else
#define FLAGTREE_REGISTRAR(...)
#endif
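
The refactor above drops the FLAGTREE_BACKEND-conditional members and the registration macros in favour of plain virtual dispatch plus a createUnifiedHardwareManager() factory. A minimal caller-side sketch (assumed usage, not part of this diff; the function name is illustrative):

#include "flagtree/Common/UnifiedHardware.h"
#include <iostream>
#include <memory>

// Hypothetical caller: only the base interface is visible, whichever backend
// (if any) supplies the concrete UnifiedHardware implementation.
void dumpHardwareInfo() {
  std::unique_ptr<mlir::flagtree::UnifiedHardware> hw =
      mlir::flagtree::createUnifiedHardwareManager();
  std::cout << "backend: " << hw->getFlagTreeBackend()
            << ", registered: " << hw->isRegistered()
            << ", dma tag: " << hw->getDMATag()
            << ", reduce strategy: " << hw->getReduceStrategy() << "\n";
}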
24 changes: 24 additions & 0 deletions include/triton/Conversion/MLIRTypes.h
@@ -21,10 +21,19 @@ inline Type u1Ty(MLIRContext *ctx) {
}

// Float types
#if LLVM_VERSION_MAJOR < 21
inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
#else // triton_v3.3.x
inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); }
inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
#endif

#if LLVM_VERSION_MAJOR < 21

inline bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
@@ -39,6 +48,21 @@ inline bool isFloat8(Type type) {
type.isFloat8E5M2FNUZ();
}

#else // triton_v3.3.x

inline bool isFloat8(Type type) {
return isa<Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType,
Float8E5M2Type, Float8E5M2FNUZType>(type);
}

inline bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
isFloat8(type);
}

#endif

inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

} // namespace type
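The version guards above bridge two MLIR API generations: before LLVM 21 the FP8 and float checks are convenience predicates on Type, from LLVM 21 onwards (the triton_v3.3.x baseline) they become a single variadic llvm::isa<> over the concrete builtin type classes. A stand-alone sketch of the same pattern, assuming MLIR headers are available (the helper name is illustrative):

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Casting.h"

// True for the four FP8 flavours this PR checks most often.
static bool isAnyFp8(mlir::Type ty) {
#if LLVM_VERSION_MAJOR < 21
  // Older MLIR: one convenience predicate per FP8 flavour.
  return ty.isFloat8E5M2() || ty.isFloat8E4M3FN() ||
         ty.isFloat8E5M2FNUZ() || ty.isFloat8E4M3FNUZ();
#else
  // Newer MLIR: the predicates are gone; query the concrete types directly.
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
                   mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(ty);
#endif
}
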
15 changes: 15 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -502,14 +502,24 @@ bool supportMMA(triton::DotOp op, int version) {
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
#if LLVM_VERSION_MAJOR < 21
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
#else // triton_v3.3.x
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
#endif
return false;
}
// We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
if (op.getMaxNumImpreciseAcc() < 32 &&
#if LLVM_VERSION_MAJOR < 21
(aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
#else // triton_v3.3.x
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
#endif
cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
return false;
}
@@ -529,8 +539,13 @@ bool supportMMA(Value value, int version) {
auto elemTy = cast<TensorOrMemDesc>(value.getType()).getElementType();
// FP8 is not natively supported on all mma versions but it can always be
// promoted to fp16 therefore we can always support it.
#if LLVM_VERSION_MAJOR < 21
bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
#else // triton_v3.3.x
bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType>(elemTy);
#endif
return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
8 changes: 8 additions & 0 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -17,7 +17,11 @@ LogicalResult UpcastMXFPOp::verify() {
auto xTy = getSrc().getType();
auto scaleTy = getScale().getType();

#if LLVM_VERSION_MAJOR < 21
if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
#else // triton_v3.3.x
if (xTy.getElementType() != BFloat16Type::get(getContext()) &&
#endif
xTy.getElementType() != IntegerType::get(getContext(), 8)) {
return emitOpError("element type of the first operand must be bf16 or i8");
}
@@ -97,7 +101,11 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
auto newShape = SmallVector<int64_t>(xShape);
newShape.back() *= 2;
inferredReturnTypes.push_back(
#if LLVM_VERSION_MAJOR < 21
RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
#else
RankedTensorType::get(newShape, BFloat16Type::get(ctx), newVEncoding));
#endif
} else {
inferredReturnTypes.push_back(xTy);
}
14 changes: 13 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -368,7 +368,11 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
NvidiaMmaEncodingAttr mmaLayout =
dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
if (mmaLayout) {
#if LLVM_VERSION_MAJOR < 21
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
#else // triton_v3.3.x
bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
#endif
// promote operands for sm < 89 since fp8 mma is not natively supported
// promote operands for sm >= 90 when mma is not v3
if (!isNativeFP8 ||
@@ -422,12 +426,20 @@ class ScaledBlockedToMMAv2
auto aType = dotOp.getLhsType();
auto bType = dotOp.getRhsType();

auto enumToType = [&rewriter](F8F6F4Type type) {
auto enumToType = [&rewriter](F8F6F4Type type) -> Type {
switch (type) {
case F8F6F4Type::E4M3:
#if LLVM_VERSION_MAJOR < 21
return rewriter.getFloat8E4M3FNType();
#else // triton_v3.3.x
return Float8E4M3FNType::get(rewriter.getContext());
#endif
case F8F6F4Type::E5M2:
#if LLVM_VERSION_MAJOR < 21
return rewriter.getFloat8E5M2Type();
#else // triton_v3.3.x
return Float8E5M2Type::get(rewriter.getContext());
#endif
default:
llvm_unreachable("unexpected type");
}
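One subtle change above: the enumToType lambda gains an explicit -> Type return type, presumably because the newer branches return two different concrete classes (Float8E4M3FNType vs Float8E5M2Type) and lambda return-type deduction requires every return statement to yield the same type. A stripped-down illustration of that C++ rule with toy types (nothing here is MLIR code):

#include <memory>

struct Base { virtual ~Base() = default; };
struct DerivedA : Base {};
struct DerivedB : Base {};

// Without the explicit "-> std::unique_ptr<Base>", deduction would see
// unique_ptr<DerivedA> in one branch and unique_ptr<DerivedB> in the other
// and reject the lambda; the declared return type gives both branches a
// common type to convert to.
auto make = [](bool a) -> std::unique_ptr<Base> {
  if (a)
    return std::make_unique<DerivedA>();
  return std::make_unique<DerivedB>();
};
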
6 changes: 6 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -43,9 +43,15 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
SmallVector<unsigned> validN;

// MMAv3 with larger instruction shape is preferred.
#if LLVM_VERSION_MAJOR < 21
if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
eltType.isF32()) {
#else // triton_v3.3.x
if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
eltType) ||
eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
#endif
validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
5 changes: 5 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -75,8 +75,13 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
const auto &d = getD();
auto aTensorTy = cast<TensorOrMemDesc>(a.getType());
auto aElTy = cast<TensorOrMemDesc>(a.getType()).getElementType();
#if LLVM_VERSION_MAJOR < 21
bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
#else // triton_v3.3.x
bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType>(aElTy);
#endif
bool accFP32 = cast<TensorOrMemDesc>(d.getType()).getElementType().isF32();
uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1];
21 changes: 20 additions & 1 deletion lib/flagtree/Common/UnifiedHardware.cc
@@ -3,7 +3,26 @@
namespace mlir {
namespace flagtree {

std::unique_ptr<UnifiedHardware> createUnifiedHardwareManager() {
bool UnifiedHardware::isRegistered() const {
#ifdef FLAGTREE_BACKEND
return true;
#else
return false;
#endif
}

int UnifiedHardware::getDMATag() const { return 0; }

int UnifiedHardware::getSharedMemoryTag() const { return 0; }

std::string UnifiedHardware::getReduceStrategy() const {
return "linalg_reduce";
}

std::string UnifiedHardware::getFlagTreeBackend() const { return "default"; }

__attribute__((weak)) std::unique_ptr<UnifiedHardware>
createUnifiedHardwareManager() {
return std::make_unique<UnifiedHardware>();
}

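Since the default createUnifiedHardwareManager() above is marked __attribute__((weak)), a backend library can link in a strong definition that returns its own subclass. A hypothetical backend-side override (class name, tag values, and backend string are illustrative, not taken from this PR):

#include "flagtree/Common/UnifiedHardware.h"
#include <memory>
#include <string>

namespace mlir {
namespace flagtree {

// Illustrative subclass; a real backend would report its own tags/strategy.
class ExampleBackendHardware : public UnifiedHardware {
public:
  bool isRegistered() const override { return true; }
  int getDMATag() const override { return 1; }            // placeholder value
  int getSharedMemoryTag() const override { return 2; }   // placeholder value
  std::string getReduceStrategy() const override { return "linalg_reduce"; }
  std::string getFlagTreeBackend() const override { return "example"; }
};

// Strong definition: at link time it takes precedence over the weak default.
std::unique_ptr<UnifiedHardware> createUnifiedHardwareManager() {
  return std::make_unique<ExampleBackendHardware>();
}

} // namespace flagtree
} // namespace mlir
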
8 changes: 7 additions & 1 deletion python/setup.py
@@ -566,7 +566,13 @@ def get_platform_dependent_src_path(subdir):
(*version.split('.'))))

if helper.flagtree_backend:
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
if helper.flagtree_backend in ("ascend"):
backends = [
*BackendInstaller.copy(helper.default_backends + helper.extend_backends),
*BackendInstaller.copy_externals(),
]
else:
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
else:
backends = [*BackendInstaller.copy(helper.default_backends), *BackendInstaller.copy_externals()]

9 changes: 6 additions & 3 deletions python/setup_tools/setup_helper.py
@@ -325,9 +325,12 @@ def check_env(env_val):

download_flagtree_third_party("triton_shared", condition=(not flagtree_backend))

download_flagtree_third_party("ascend", condition=(flagtree_backend == "ascend"), hock=utils.ascend.precompile_hock,
download_flagtree_third_party("flir", condition=(flagtree_backend == "ascend"), hock=utils.ascend.precompile_hook_flir,
required=True)

#download_flagtree_third_party("ascend", condition=(flagtree_backend == "ascend"), hock=utils.ascend.precompile_hook,
# required=True)

handle_flagtree_backend()

cache = FlagTreeCache()
@@ -387,9 +390,9 @@ def check_env(env_val):

# ascend
cache.store(
file="llvm-b5cc222d-ubuntu-arm64",
file="llvm-a66376b0-ubuntu-arm64",
condition=("ascend" == flagtree_backend),
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz",
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-arm64.tar.gz",
pre_hock=lambda: check_env('LLVM_SYSPATH'),
post_hock=set_llvm_env,
)
10 changes: 7 additions & 3 deletions python/setup_tools/utils/__init__.py
@@ -9,9 +9,13 @@
tools.Module(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
commit_id="380b87122c88af131530903a702d5318ec59bb33",
dst_path=os.path.join(flagtree_submodule_dir, "triton_shared")),
"ascend":
tools.Module(name="ascend", url="https://gitcode.com/FlagTree/triton-ascend.git",
dst_path=os.path.join(flagtree_submodule_dir, "triton_ascend")),
"flir":
tools.Module(name="flir", url="https://github.com/FlagTree/flir.git",
dst_path=os.path.join(flagtree_submodule_dir, "flir")),
#"ascend":
#tools.Module(name="ascend", url="https://gitcode.com/FlagTree/triton-ascend.git",
# commit_id="18803572c3aaf55b914090560fe8a31bc5eaa2cc", # ascend_with_llvma66376b0_20251021_debug
# dst_path=os.path.join(flagtree_submodule_dir, "triton_ascend")),
}


23 changes: 15 additions & 8 deletions python/setup_tools/utils/ascend.py
@@ -3,13 +3,19 @@
from pathlib import Path
from setup_tools.utils.tools import flagtree_root_dir, Module, flagtree_submodule_dir, DownloadManager

def precompile_hook_flir(*args, **kargs):
default_backends = kargs["default_backends"]
if 'amd' in default_backends:
default_backends.remove('amd')
default_backends.append('flir')

downloader = DownloadManager()

submodules = (Module(name="ascendnpu-ir", url="https://gitee.com/ascend/ascendnpu-ir.git",
commit_id="1922371c42749fda534d6395b7ed828b5c9f36d4",
dst_path=os.path.join(flagtree_submodule_dir, "ascend/third_party/ascendnpu-ir")), )


'''
def get_backend_cmake_args(*args, **kargs):
build_ext = kargs['build_ext']
src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt")
@@ -24,8 +30,8 @@ def install_extension(*args, **kargs):
python_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
dst_ext_path = os.path.join(python_root_dir, "triton/backends/ascend/triton-adapter-opt")
shutil.copy(src_ext_path, dst_ext_path)


'''
'''
def create_symlink_for_triton(link_map):
for target, source in link_map.items():
target_path = Path(os.path.join(flagtree_root_dir, "python", target))
@@ -91,22 +97,22 @@ def get_package_dir():
create_symlink_for_triton(package_dict)
raise RuntimeError("will Fixed")
return package_dict


'''
'''
def get_extra_install_packages():
return [
"triton/triton_patch",
"triton/triton_patch/language",
"triton/triton_patch/compiler",
"triton/triton_patch/runtime",
]

'''

def is_compile_ascend_npu_ir():
return os.getenv("ASCEND_NPU_IR_COMPILE", "1") == "1"


def precompile_hock(*args, **kargs):
'''
def precompile_hook(*args, **kargs):
third_party_base_dir = Path(kargs['third_party_base_dir'])
ascend_path = Path(third_party_base_dir) / "ascend"
patch_path = Path(ascend_path) / "triton_patch"
Expand Down Expand Up @@ -150,3 +156,4 @@ def precompile_hock(*args, **kargs):
except Exception as e:
print(f"[ERROR]: Unknown error: {str(e)}")
return False
'''