diff --git a/.gitignore b/.gitignore
index 69f96e552..9483742b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,7 +22,6 @@ python/triton/backends/
 third_party/iluvatar/iluvatarTritonPlugin.so
 third_party/triton_shared/
 third_party/xpu/backend/xpu3
-third_party/ascend

 # Proton
 python/triton/profiler
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bee80b3c3..e2008ebfa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,6 +33,8 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
 elseif(FLAGTREE_BACKEND STREQUAL "ascend")
   set(CMAKE_C_COMPILER clang)
   set(CMAKE_CXX_COMPILER clang++)
+  add_compile_options("-Wno-deprecated-declarations")
+  add_compile_options("-Wno-error=deprecated-declarations")
 endif()
 set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}")
 if(FLAGTREE_PLUGIN)
@@ -476,7 +478,7 @@ endif()

 add_subdirectory(third_party/f2reduce)

-if(NOT FLAGTREE_BACKEND)
+if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND MATCHES "^(aipu|ascend|tsingmicro)$")
   add_subdirectory(bin)
   add_subdirectory(test)
 endif()
diff --git a/include/flagtree/Common/UnifiedHardware.h b/include/flagtree/Common/UnifiedHardware.h
index 3c821e783..5ec1c241c 100644
--- a/include/flagtree/Common/UnifiedHardware.h
+++ b/include/flagtree/Common/UnifiedHardware.h
@@ -15,36 +15,16 @@ namespace flagtree {

 class UnifiedHardware {
 public:
-  ~UnifiedHardware() = default;
   UnifiedHardware() = default;
-#ifdef FLAGTREE_BACKEND
-  static bool registered;
-  int getDMATag();
-  int getSharedMemoryTag();
-  std::string getFlagTreeBackend() { return FLAGTREE_BACKEND; }
-#else
-  static constexpr bool registered = false;
-  void *getDMATag() { return nullptr; }
-  void *getSharedMemoryTag() { return nullptr; }
-  std::string getFlagTreeBackend() { return "default"; }
-#endif
+  virtual ~UnifiedHardware() = default;
+  virtual bool isRegistered() const;
+  virtual int getDMATag() const;
+  virtual int getSharedMemoryTag() const;
+  virtual std::string getReduceStrategy() const;
+  virtual std::string getFlagTreeBackend() const;
 };

 std::unique_ptr<UnifiedHardware> createUnifiedHardwareManager();

 } // namespace flagtree
 } // namespace mlir
-
-#define SET_REGISTER_FLAG(_Ty, FLAG) bool _Ty::registered = FLAG;
-
-#define FLAGTREE_REGISTRAR_GET(_Ty, _Fn, _VAL) \
-  decltype(_VAL) _Ty::get##_Fn() { return static_cast<decltype(_VAL)>(_VAL); }
-
-#ifdef FLAGTREE_BACKEND
-#define FLAGTREE_REGISTRAR(fn_name, _VAL) \
-  using UnifiedHardwareType = mlir::flagtree::UnifiedHardware; \
-  FLAGTREE_REGISTRAR_GET(UnifiedHardwareType, fn_name, _VAL) \
-  SET_REGISTER_FLAG(UnifiedHardwareType, true)
-#else
-#define FLAGTREE_REGISTRAR(...)
-#endif
diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
index a00f9f844..6163939ed 100644
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -21,10 +21,19 @@ inline Type u1Ty(MLIRContext *ctx) {
 }

 // Float types
+#if LLVM_VERSION_MAJOR < 21
 inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
 inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
+#else // triton_v3.3.x
+inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); }
+inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
+inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
+inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
+#endif
+

+#if LLVM_VERSION_MAJOR < 21
 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
@@ -39,6 +48,21 @@ inline bool isFloat8(Type type) {
          type.isFloat8E5M2FNUZ();
 }

+#else // triton_v3.3.x
+
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
+             Float8E5M2Type, Float8E5M2FNUZType>(type);
+}
+
+inline bool isFloat(Type type) {
+  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
+         type.isBF16() || llvm::isa(type) ||
+         isFloat8(type);
+}
+
+#endif
+
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

 } // namespace type
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 4915d7b1a..00a355c5e 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -502,14 +502,24 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
+#if LLVM_VERSION_MAJOR < 21
           (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
           aElemTy.isF32()))) {
+#else // triton_v3.3.x
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
+           aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
+           aElemTy.isF32()))) {
+#endif
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
+#if LLVM_VERSION_MAJOR < 21
         (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+#else // triton_v3.3.x
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
+#endif
         cast(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -529,8 +539,13 @@ bool supportMMA(Value value, int version) {
   auto elemTy = cast(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
+#if LLVM_VERSION_MAJOR < 21
   bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+#else // triton_v3.3.x
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
+#endif
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index e61fe096e..a5a101bcf 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -17,7 +17,11 @@ LogicalResult UpcastMXFPOp::verify() {
   auto xTy = getSrc().getType();
   auto scaleTy = getScale().getType();

+#if LLVM_VERSION_MAJOR < 21
   if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
+#else // triton_v3.3.x
+  if (xTy.getElementType() != BFloat16Type::get(getContext()) &&
+#endif
       xTy.getElementType() != IntegerType::get(getContext(), 8)) {
     return emitOpError("element type of the first operand must be bf16 or i8");
   }
@@ -97,7 +101,11 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     auto newShape = SmallVector<int64_t>(xShape);
     newShape.back() *= 2;
     inferredReturnTypes.push_back(
+#if LLVM_VERSION_MAJOR < 21
         RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
+#else
+        RankedTensorType::get(newShape, BFloat16Type::get(ctx), newVEncoding));
+#endif
   } else {
     inferredReturnTypes.push_back(xTy);
   }
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index e26118bdf..12734a5db 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -368,7 +368,11 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
+#if LLVM_VERSION_MAJOR < 21
       bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+#else // triton_v3.3.x
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
+#endif
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
@@ -422,12 +426,20 @@ class ScaledBlockedToMMAv2
     auto aType = dotOp.getLhsType();
     auto bType = dotOp.getRhsType();

-    auto enumToType = [&rewriter](F8F6F4Type type) {
+    auto enumToType = [&rewriter](F8F6F4Type type) -> Type {
       switch (type) {
       case F8F6F4Type::E4M3:
+#if LLVM_VERSION_MAJOR < 21
         return rewriter.getFloat8E4M3FNType();
+#else // triton_v3.3.x
+        return Float8E4M3FNType::get(rewriter.getContext());
+#endif
       case F8F6F4Type::E5M2:
+#if LLVM_VERSION_MAJOR < 21
         return rewriter.getFloat8E5M2Type();
+#else // triton_v3.3.x
+        return Float8E5M2Type::get(rewriter.getContext());
+#endif
       default:
         llvm_unreachable("unexpected type");
       }
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 4ef9d1cd1..277c3b5bc 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -43,9 +43,15 @@ SmallVector mmaVersionToInstrShape(int version,
   SmallVector validN;

   // MMAv3 with larger instruction shape is preferred.
+#if LLVM_VERSION_MAJOR < 21
   if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
       eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
       eltType.isF32()) {
+#else // triton_v3.3.x
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+          eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
+#endif
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168,
                    160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64,
                    56, 48, 40, 32, 24, 16, 8});
diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
index 088ca2663..e7c5a804f 100644
--- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -75,8 +75,13 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast(a.getType());
   auto aElTy = cast(a.getType()).getElementType();
+#if LLVM_VERSION_MAJOR < 21
   bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+#else // triton_v3.3.x
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(aElTy);
+#endif
   bool accFP32 = cast(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
   return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1];
diff --git a/lib/flagtree/Common/UnifiedHardware.cc b/lib/flagtree/Common/UnifiedHardware.cc
index b09eca314..c9449ad49 100644
--- a/lib/flagtree/Common/UnifiedHardware.cc
+++ b/lib/flagtree/Common/UnifiedHardware.cc
@@ -3,7 +3,26 @@
 namespace mlir {
 namespace flagtree {

-std::unique_ptr<UnifiedHardware> createUnifiedHardwareManager() {
+bool UnifiedHardware::isRegistered() const {
+#ifdef FLAGTREE_BACKEND
+  return true;
+#else
+  return false;
+#endif
+}
+
+int UnifiedHardware::getDMATag() const { return 0; }
+
+int UnifiedHardware::getSharedMemoryTag() const { return 0; }
+
+std::string UnifiedHardware::getReduceStrategy() const {
+  return "linalg_reduce";
+}
+
+std::string UnifiedHardware::getFlagTreeBackend() const { return "default"; }
+
+__attribute__((weak)) std::unique_ptr<UnifiedHardware>
+createUnifiedHardwareManager() {
   return std::make_unique<UnifiedHardware>();
 }
diff --git a/python/setup.py b/python/setup.py
index 6ca2fd8ac..17af85d08 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -566,7 +566,13 @@ def get_platform_dependent_src_path(subdir):
                                          (*version.split('.'))))

 if helper.flagtree_backend:
-    backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
+    if helper.flagtree_backend in ("ascend",):
+        backends = [
+            *BackendInstaller.copy(helper.default_backends + helper.extend_backends),
+            *BackendInstaller.copy_externals(),
+        ]
+    else:
+        backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
 else:
     backends = [*BackendInstaller.copy(helper.default_backends), *BackendInstaller.copy_externals()]

diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py
index 3fe0a49d6..2bb2146e9 100644
--- a/python/setup_tools/setup_helper.py
+++ b/python/setup_tools/setup_helper.py
@@ -325,9 +325,12 @@ def check_env(env_val):

 download_flagtree_third_party("triton_shared", condition=(not flagtree_backend))
-download_flagtree_third_party("ascend", condition=(flagtree_backend == "ascend"), hock=utils.ascend.precompile_hock,
+download_flagtree_third_party("flir", condition=(flagtree_backend == "ascend"), hock=utils.ascend.precompile_hook_flir,
                               required=True)
+#download_flagtree_third_party("ascend", condition=(flagtree_backend == "ascend"), hock=utils.ascend.precompile_hook,
+#                              required=True)
+
handle_flagtree_backend() cache = FlagTreeCache() @@ -387,9 +390,9 @@ def check_env(env_val): # ascend cache.store( - file="llvm-b5cc222d-ubuntu-arm64", + file="llvm-a66376b0-ubuntu-arm64", condition=("ascend" == flagtree_backend), - url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz", + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-a66376b0-ubuntu-arm64.tar.gz", pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py index 6bd72172f..01b160aec 100644 --- a/python/setup_tools/utils/__init__.py +++ b/python/setup_tools/utils/__init__.py @@ -9,9 +9,13 @@ tools.Module(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", commit_id="380b87122c88af131530903a702d5318ec59bb33", dst_path=os.path.join(flagtree_submodule_dir, "triton_shared")), - "ascend": - tools.Module(name="ascend", url="https://gitcode.com/FlagTree/triton-ascend.git", - dst_path=os.path.join(flagtree_submodule_dir, "triton_ascend")), + "flir": + tools.Module(name="flir", url="https://github.com/FlagTree/flir.git", + dst_path=os.path.join(flagtree_submodule_dir, "flir")), + #"ascend": + #tools.Module(name="ascend", url="https://gitcode.com/FlagTree/triton-ascend.git", + # commit_id="18803572c3aaf55b914090560fe8a31bc5eaa2cc", # ascend_with_llvma66376b0_20251021_debug + # dst_path=os.path.join(flagtree_submodule_dir, "triton_ascend")), } diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py index d28803721..4ad382cb1 100644 --- a/python/setup_tools/utils/ascend.py +++ b/python/setup_tools/utils/ascend.py @@ -3,13 +3,19 @@ from pathlib import Path from setup_tools.utils.tools import flagtree_root_dir, Module, flagtree_submodule_dir, DownloadManager +def precompile_hook_flir(*args, **kargs): + default_backends = kargs["default_backends"] + if 'amd' in default_backends: + default_backends.remove('amd') + default_backends.append('flir') + downloader = DownloadManager() submodules = (Module(name="ascendnpu-ir", url="https://gitee.com/ascend/ascendnpu-ir.git", commit_id="1922371c42749fda534d6395b7ed828b5c9f36d4", dst_path=os.path.join(flagtree_submodule_dir, "ascend/third_party/ascendnpu-ir")), ) - +''' def get_backend_cmake_args(*args, **kargs): build_ext = kargs['build_ext'] src_ext_path = build_ext.get_ext_fullpath("triton-adapter-opt") @@ -24,8 +30,8 @@ def install_extension(*args, **kargs): python_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) dst_ext_path = os.path.join(python_root_dir, "triton/backends/ascend/triton-adapter-opt") shutil.copy(src_ext_path, dst_ext_path) - - +''' +''' def create_symlink_for_triton(link_map): for target, source in link_map.items(): target_path = Path(os.path.join(flagtree_root_dir, "python", target)) @@ -91,8 +97,8 @@ def get_package_dir(): create_symlink_for_triton(package_dict) raise RuntimeError("will Fixed") return package_dict - - +''' +''' def get_extra_install_packages(): return [ "triton/triton_patch", @@ -100,13 +106,13 @@ def get_extra_install_packages(): "triton/triton_patch/compiler", "triton/triton_patch/runtime", ] - +''' def is_compile_ascend_npu_ir(): return os.getenv("ASCEND_NPU_IR_COMPILE", "1") == "1" - -def precompile_hock(*args, **kargs): +''' +def precompile_hook(*args, **kargs): third_party_base_dir = Path(kargs['third_party_base_dir']) ascend_path = Path(third_party_base_dir) / "ascend" patch_path = Path(ascend_path) / 
"triton_patch" @@ -150,3 +156,4 @@ def precompile_hock(*args, **kargs): except Exception as e: print(f"[ERROR]: Unknown error: {str(e)}") return False +''' diff --git a/python/src/passes.cc b/python/src/passes.cc index 37bb392da..de1d21030 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -81,7 +81,11 @@ void init_triton_passes_ttgpuir(py::module &&m) { void init_triton_passes_convert(py::module &&m) { using namespace mlir; +#if LLVM_VERSION_MAJOR < 21 ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); +#else // triton_v3.3.x + ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); +#endif ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); diff --git a/python/test/ops/01_vector_add/01_vector_add.py b/python/test/ops/01_vector_add/01_vector_add.py new file mode 100644 index 000000000..accc56cd9 --- /dev/null +++ b/python/test/ops/01_vector_add/01_vector_add.py @@ -0,0 +1,110 @@ +""" +Vector Addition +=============== + +In this tutorial, you will write a simple vector addition using Triton. + +In doing so, you will learn about: + +* The basic programming model of Triton. + +* The `triton.jit` decorator, which is used to define Triton kernels. + +* The best practices for validating and benchmarking your custom ops against native reference implementations. + +""" + +# %% +# Compute Kernel +# -------------- + +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# %% +# Let's also declare a helper function to (1) allocate the `z` tensor +# and (2) enqueue the above kernel with appropriate grid/block sizes: + + +def add(x: torch.Tensor, y: torch.Tensor): + # We need to preallocate the output. 
+ output = torch.empty_like(x) + n_elements = output.numel() + # The SPMD launch grid denotes the number of kernel instances that run in parallel. + # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks: + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + with torch_device_fn.device(x.device): + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. + return output.cpu() + + +# %% +# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects +# and test its correctness: + + +def test_vector_add(): + torch.manual_seed(0) + size = 4432 + x = torch.rand(size, device=device) + y = torch.rand(size, device=device) + output_torch = x.cpu() + y.cpu() + + output_triton = add(x, y) + print( + f"The maximum difference between torch and triton is " + f"{torch.max(torch.abs(output_torch - output_triton))}" + ) + assert torch.allclose(output_triton, output_torch), (output_triton, output_torch) + + +if __name__ == "__main__": + test_vector_add() diff --git a/python/test/ops/01_vector_add/triton-ascend/add_kernel.ttadapter b/python/test/ops/01_vector_add/triton-ascend/add_kernel.ttadapter new file mode 100644 index 000000000..5d3bd6bce --- /dev/null +++ b/python/test/ops/01_vector_add/triton-ascend/add_kernel.ttadapter @@ -0,0 +1,40 @@ +module { + func.func @add_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %cst = arith.constant 0.000000e+00 : f32 + %c1024 = arith.constant 1024 : index + %c1024_i32 = arith.constant 1024 : i32 + %0 = arith.muli %arg9, %c1024_i32 : i32 + %1 = arith.index_cast %0 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%1], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> + %alloc = memref.alloc() : memref<1024xf32> + %2 = arith.addi %1, %c1024 : index + %3 = arith.index_cast %arg5 : i32 to index + %4 = arith.maxsi %1, %3 : index + %5 = arith.minsi %2, %4 : index + %6 = arith.subi %5, %1 : index + %7 = arith.cmpi slt, %6, %c1024 : index + scf.if %7 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<1024xf32>) + } + %subview = memref.subview %reinterpret_cast[0] [%6] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + %subview_0 = memref.subview %alloc[0] [%6] [1] : memref<1024xf32> to memref> + memref.copy %subview, %subview_0 : memref> to memref> + %8 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg3 to offset: [%1], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> + %alloc_2 = 
memref.alloc() : memref<1024xf32> + scf.if %7 { + linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<1024xf32>) + } + %subview_3 = memref.subview %reinterpret_cast_1[0] [%6] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + %subview_4 = memref.subview %alloc_2[0] [%6] [1] : memref<1024xf32> to memref> + memref.copy %subview_3, %subview_4 : memref> to memref> + %9 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32> + %10 = arith.addf %8, %9 : tensor<1024xf32> + %reinterpret_cast_5 = memref.reinterpret_cast %arg4 to offset: [%1], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> + %extracted_slice = tensor.extract_slice %10[0] [%6] [1] : tensor<1024xf32> to tensor + %subview_6 = memref.subview %reinterpret_cast_5[0] [%6] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_6 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/01_vector_add/triton-ascend/add_kernel.ttir b/python/test/ops/01_vector_add/triton-ascend/add_kernel.ttir new file mode 100644 index 000000000..c817f2657 --- /dev/null +++ b/python/test/ops/01_vector_add/triton-ascend/add_kernel.ttir @@ -0,0 +1,39 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":37:0) +module { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":37:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":37:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":37:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":37:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5) + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6) + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6) + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc7) + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc7) + %9 = tt.load %8, %6, %cst : tensor<1024x!tt.ptr> loc(#loc8) + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc9) + %12 = tt.load %11, %6, %cst : tensor<1024x!tt.ptr> loc(#loc10) + %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc11) + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc12) + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc12) + tt.store %15, %13, %6 : tensor<1024x!tt.ptr> loc(#loc13) + tt.return loc(#loc14) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":47:24) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":52:24) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":53:41) +#loc5 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":53:28) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":55:21) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":58:24) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":58:16) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":59:24) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":59:16) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":60:17) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":62:26) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":62:35) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/01_vector_add/01_vector_add.py":62:4) diff --git a/python/test/ops/abs/abs.py b/python/test/ops/abs/abs.py new file mode 100644 index 000000000..1fa9de7a3 --- /dev/null +++ b/python/test/ops/abs/abs.py @@ -0,0 +1,81 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def abs_kernel( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = tl.abs(x) + tl.store(output_ptr + offsets, output, mask=mask) + + +def abs(x: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + with torch_device_fn.device(x.device): + abs_kernel[grid](x, output, n_elements, BLOCK_SIZE=1024) + return output + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (4, 13) + dtype = torch.float32 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.abs(ref_inp) + res_out = abs(inp) + check("value", ref_out, res_out) diff --git a/python/test/ops/abs/triton-ascend/abs_kernel.ttadapter b/python/test/ops/abs/triton-ascend/abs_kernel.ttadapter new file mode 100644 index 000000000..f105b35ea --- /dev/null +++ b/python/test/ops/abs/triton-ascend/abs_kernel.ttadapter @@ -0,0 +1,31 @@ +module { + func.func @abs_kernel(%arg0: memref, %arg1: memref, %arg2: memref 
{tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %cst = arith.constant 0.000000e+00 : f32 + %c1024 = arith.constant 1024 : index + %c1024_i32 = arith.constant 1024 : i32 + %0 = arith.muli %arg8, %c1024_i32 : i32 + %1 = arith.index_cast %0 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%1], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> + %alloc = memref.alloc() : memref<1024xf32> + %2 = arith.addi %1, %c1024 : index + %3 = arith.index_cast %arg4 : i32 to index + %4 = arith.maxsi %1, %3 : index + %5 = arith.minsi %2, %4 : index + %6 = arith.subi %5, %1 : index + %7 = arith.cmpi slt, %6, %c1024 : index + scf.if %7 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<1024xf32>) + } + %subview = memref.subview %reinterpret_cast[0] [%6] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + %subview_0 = memref.subview %alloc[0] [%6] [1] : memref<1024xf32> to memref> + memref.copy %subview, %subview_0 : memref> to memref> + %8 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> + %9 = math.absf %8 : tensor<1024xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg3 to offset: [%1], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> + %extracted_slice = tensor.extract_slice %9[0] [%6] [1] : tensor<1024xf32> to tensor + %subview_2 = memref.subview %reinterpret_cast_1[0] [%6] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_2 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/abs/triton-ascend/abs_kernel.ttir b/python/test/ops/abs/triton-ascend/abs_kernel.ttir new file mode 100644 index 000000000..01ee85fb1 --- /dev/null +++ b/python/test/ops/abs/triton-ascend/abs_kernel.ttir @@ -0,0 +1,34 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":17:0) +module { + tt.func public @abs_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":17:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":17:0), %arg2: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":17:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5) + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> loc(#loc6) + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6) + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc7) + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc7) + %9 = tt.load %8, %6, %cst : tensor<1024x!tt.ptr> loc(#loc8) + %10 = math.absf %9 : tensor<1024xf32> loc(#loc9) + %11 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc10) + %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc10) + tt.store %12, %10, %6 : 
tensor<1024x!tt.ptr> loc(#loc11) + tt.return loc(#loc12) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":23:24) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":24:24) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":25:41) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":25:28) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":26:21) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":27:24) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":27:16) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":28:20) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":29:26) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":29:35) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/abs/abs.py":29:4) diff --git a/python/test/ops/addmm/addmm.py b/python/test/ops/addmm/addmm.py new file mode 100644 index 000000000..a5b384491 --- /dev/null +++ b/python/test/ops/addmm/addmm.py @@ -0,0 +1,163 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit(do_not_specialize=["alpha", "beta"]) +def addmm_kernel( + a_ptr, + b_ptr, + i_ptr, + c_ptr, + alpha, + beta, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_im, + stride_in, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr = 8, + BLOCK_SIZE_N: tl.constexpr = 8, + BLOCK_SIZE_K: tl.constexpr = 8, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N), + other=0.0, + ) + accumulator += tl.dot(a, b, allow_tf32=False) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :] + bias = tl.load(i_ptrs, mask=c_mask, other=0.0) + + accumulator = accumulator * alpha + bias * beta + c = accumulator.to(bias.dtype) + tl.store(c_ptrs, c, mask=c_mask) + + +def addmm(bias, mat1, mat2, *, beta=1, alpha=1): + assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions" + M, K = mat1.shape + _, N = mat2.shape + + mat1 = mat1.contiguous() + out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype) + + 
grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + with torch_device_fn.device(mat1.device): + addmm_kernel[grid]( + mat1, + mat2, + bias, + out, + alpha, + beta, + M, + N, + K, + mat1.stride(0), + mat1.stride(1), + mat2.stride(0), + mat2.stride(1), + bias.stride(0), + bias.stride(1), + out.stride(0), + out.stride(1), + ) + return out + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + M = 4 + N = 2 + K = 8 + scalar = 0.001 + dtype = torch.float32 + b_column_major = True + + # inp + mat1 = torch.randn((M, K), dtype=dtype, device=device) + if b_column_major: + mat2 = torch.randn((N, K), dtype=dtype, device=device).t() + else: + mat2 = torch.randn((K, N), dtype=dtype, device=device) + bias2 = torch.randn((M,N), dtype=dtype, device=device) + ref_mat1 = mat1.cpu() + ref_mat2 = mat2.cpu() + ref_bias2 = bias2.cpu() + alpha = beta = scalar + + # op + ref_out2 = torch.addmm(ref_bias2, ref_mat1, ref_mat2, alpha=alpha, beta=beta) + res_out2 = addmm(bias2, mat1, mat2, alpha=alpha, beta=beta) + check("value", ref_out2, res_out2, reduce_dim=K) diff --git a/python/test/ops/addmm/addmm_ascend.py b/python/test/ops/addmm/addmm_ascend.py new file mode 100644 index 000000000..12e15b8f6 --- /dev/null +++ b/python/test/ops/addmm/addmm_ascend.py @@ -0,0 +1,165 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit(do_not_specialize=["alpha", "beta"]) +def addmm_kernel( + a_ptr, + b_ptr, + i_ptr, + c_ptr, + alpha, + beta, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_im, + stride_in, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr = 8, + BLOCK_SIZE_N: tl.constexpr = 8, + BLOCK_SIZE_K: tl.constexpr = 8, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + accumulator += tl.dot(a, b, allow_tf32=False) + a_ptrs += BLOCK_SIZE_K * 
stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :] + bias = tl.load(i_ptrs, mask=c_mask, other=0.0) + bias1 = bias.to(accumulator.dtype) + accumulator = accumulator * alpha + bias1 * beta + c = accumulator.to(bias.dtype) + tl.store(c_ptrs, c, mask=c_mask) + + +def addmm(bias, mat1, mat2, *, beta=1, alpha=1): + assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions" + M, K = mat1.shape + _, N = mat2.shape + + mat1 = mat1.contiguous() + mat2 = mat2.contiguous() # ascend need + out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype) + bias = bias.broadcast_to(out.shape).contiguous() + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + with torch_device_fn.device(mat1.device): + addmm_kernel[grid]( + mat1, + mat2, + bias, + out, + alpha, + beta, + M, + N, + K, + mat1.stride(0), + mat1.stride(1), + mat2.stride(0), + mat2.stride(1), + bias.stride(0), + bias.stride(1), + out.stride(0), + out.stride(1), + ) + return out + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + M = 4 + N = 2 + K = 8 + scalar = 0.001 + dtype = torch.float32 + b_column_major = True + + # inp + mat1 = torch.randn((M, K), dtype=dtype, device=device) + if b_column_major: + mat2 = torch.randn((N, K), dtype=dtype, device=device).t() + else: + mat2 = torch.randn((K, N), dtype=dtype, device=device) + bias2 = torch.randn((M,N), dtype=dtype, device=device) + ref_mat1 = mat1.cpu() + ref_mat2 = mat2.cpu() + ref_bias2 = bias2.cpu() + alpha = beta = scalar + + # op + ref_out2 = torch.addmm(ref_bias2, ref_mat1, ref_mat2, alpha=alpha, beta=beta) + res_out2 = addmm(bias2, mat1, mat2, alpha=alpha, beta=beta) + check("value", ref_out2, res_out2, reduce_dim=K) diff --git a/python/test/ops/addmm/git.py b/python/test/ops/addmm/git.py new file mode 100644 index 000000000..fbcd5493b --- /dev/null +++ b/python/test/ops/addmm/git.py @@ -0,0 +1,118 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems import runtime +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import broadcastable_to, libentry +from flag_gems.utils import triton_lang_extension as tle + +logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') + + +@libentry() +@triton.autotune( + configs=runtime.get_tuned_config("addmm"), + key=["M", "N", "K"], +) +@triton.jit(do_not_specialize=["alpha", "beta"]) +def addmm_kernel( + a_ptr, + b_ptr, + i_ptr, + c_ptr, + alpha, + beta, + M, + N, + K, + 
stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_im, + stride_in, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tle.program_id(0) + pid_n = tle.program_id(1) + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + accumulator += tl.dot(a, b, allow_tf32=False) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :] + bias = tl.load(i_ptrs, mask=c_mask, other=0.0) + bias1 = bias.to(accumulator.dtype) + accumulator = accumulator * alpha + bias1 * beta + c = accumulator.to(bias.dtype) + tl.store(c_ptrs, c, mask=c_mask) + + +def addmm(bias, mat1, mat2, *, beta=1, alpha=1): + logger.debug("GEMS_ASCEND ADDMM") + assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions" + assert broadcastable_to( + bias.shape, (mat1.shape[0], mat2.shape[1]) + ), "Incompatible input shape" + M, K = mat1.shape + _, N = mat2.shape + + mat1 = mat1.contiguous() + mat2 = mat2.contiguous() + out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype) + bias = bias.broadcast_to(out.shape).contiguous() + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + with torch_device_fn.device(mat1.device): + addmm_kernel[grid]( + mat1, + mat2, + bias, + out, + alpha, + beta, + M, + N, + K, + mat1.stride(0), + mat1.stride(1), + mat2.stride(0), + mat2.stride(1), + bias.stride(0), + bias.stride(1), + out.stride(0), + out.stride(1), + ) + return out \ No newline at end of file diff --git a/python/test/ops/addmm/triton-ascend/addmm_kernel.ttadapter b/python/test/ops/addmm/triton-ascend/addmm_kernel.ttadapter new file mode 100644 index 000000000..d363492b2 --- /dev/null +++ b/python/test/ops/addmm/triton-ascend/addmm_kernel.ttadapter @@ -0,0 +1,101 @@ +module { + func.func @addmm_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg5: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg6: f32, %arg7: f32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32, %arg18: i32, %arg19: i32, %arg20: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "mix"} { + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c7_i32 = arith.constant 7 : i32 + 
%c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + %2 = arith.muli %arg18, %c8_i32 : i32 + %3 = arith.muli %arg19, %c8_i32 : i32 + %4 = arith.addi %arg10, %c7_i32 : i32 + %5 = arith.divsi %4, %c8_i32 : i32 + %6 = arith.muli %arg12, %c8_i32 : i32 + %7 = arith.index_cast %2 : i32 to index + %8 = arith.index_cast %arg11 : i32 to index + %9 = arith.muli %7, %8 : index + %10 = arith.index_cast %arg12 : i32 to index + %11 = arith.index_cast %3 : i32 to index + %12:3 = scf.for %arg21 = %c0_i32 to %5 step %c1_i32 iter_args(%arg22 = %1, %arg23 = %9, %arg24 = %c0) -> (tensor<8x8xf32>, index, index) : i32 { + %40 = arith.addi %arg24, %11 : index + %reinterpret_cast_3 = memref.reinterpret_cast %arg3 to offset: [%40], sizes: [8, 8], strides: [%10, %c1] : memref to memref<8x8xf32, strided<[?, ?], offset: ?>> + %reinterpret_cast_4 = memref.reinterpret_cast %arg2 to offset: [%arg23], sizes: [8, 8], strides: [%8, %c1] : memref to memref<8x8xf32, strided<[?, ?], offset: ?>> + %41 = arith.muli %arg21, %c8_i32 : i32 + %42 = arith.subi %arg10, %41 : i32 + %alloc_5 = memref.alloc() : memref<8x8xf32> + %43 = arith.index_cast %42 : i32 to index + %44 = arith.maxsi %43, %c0 : index + %45 = arith.minsi %44, %c8 : index + %46 = arith.cmpi slt, %45, %c8 : index + scf.if %46 { + linalg.fill ins(%cst : f32) outs(%alloc_5 : memref<8x8xf32>) + } + %subview_6 = memref.subview %reinterpret_cast_4[0, 0] [8, %45] [1, 1] : memref<8x8xf32, strided<[?, ?], offset: ?>> to memref<8x?xf32, strided<[?, ?], offset: ?>> + %subview_7 = memref.subview %alloc_5[0, 0] [8, %45] [1, 1] : memref<8x8xf32> to memref<8x?xf32, strided<[8, 1]>> + memref.copy %subview_6, %subview_7 : memref<8x?xf32, strided<[?, ?], offset: ?>> to memref<8x?xf32, strided<[8, 1]>> + annotation.mark %subview_7 {MayImplicitTransposeWithLastAxis} : memref<8x?xf32, strided<[8, 1]>> + %47 = bufferization.to_tensor %alloc_5 restrict writable : memref<8x8xf32> + annotation.mark %47 {MayImplicitTransposeWithLastAxis} : tensor<8x8xf32> + %alloc_8 = memref.alloc() : memref<8x8xf32> + scf.if %46 { + linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<8x8xf32>) + } + %subview_9 = memref.subview %reinterpret_cast_3[0, 0] [%45, 8] [1, 1] : memref<8x8xf32, strided<[?, ?], offset: ?>> to memref> + %subview_10 = memref.subview %alloc_8[0, 0] [%45, 8] [1, 1] : memref<8x8xf32> to memref> + memref.copy %subview_9, %subview_10 : memref> to memref> + annotation.mark %subview_10 {MayImplicitTransposeWithLastAxis} : memref> + %48 = bufferization.to_tensor %alloc_8 restrict writable : memref<8x8xf32> + annotation.mark %48 {MayImplicitTransposeWithLastAxis} : tensor<8x8xf32> + %49 = linalg.matmul {input_precison = "ieee"} ins(%47, %48 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%arg22 : tensor<8x8xf32>) -> tensor<8x8xf32> + %50 = arith.addi %arg23, %c8 : index + %51 = arith.index_cast %6 : i32 to index + %52 = arith.addi %arg24, %51 : index + scf.yield %49, %50, %52 : tensor<8x8xf32>, index, index + } + %13 = arith.index_cast %arg14 : i32 to index + %14 = arith.muli %7, %13 : index + %15 = arith.addi %14, %11 : index + %reinterpret_cast = memref.reinterpret_cast %arg5 to offset: [%15], sizes: [8, 8], strides: [%13, 1] : memref to memref<8x8xf32, strided<[?, 1], offset: ?>> + %16 = arith.index_cast %arg13 : i32 to index + %17 = arith.muli %7, %16 : index + %18 = arith.addi %17, %11 : index + %reinterpret_cast_0 = 
memref.reinterpret_cast %arg4 to offset: [%18], sizes: [8, 8], strides: [%16, 1] : memref to memref<8x8xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x8xf32> + %19 = arith.addi %7, %c8 : index + %20 = arith.index_cast %arg8 : i32 to index + %21 = arith.maxsi %7, %20 : index + %22 = arith.minsi %19, %21 : index + %23 = arith.subi %22, %7 : index + %24 = arith.addi %11, %c8 : index + %25 = arith.index_cast %arg9 : i32 to index + %26 = arith.maxsi %11, %25 : index + %27 = arith.minsi %24, %26 : index + %28 = arith.subi %27, %11 : index + %29 = arith.minsi %23, %c8 : index + %30 = arith.minsi %28, %c8 : index + %31 = arith.cmpi slt, %29, %c8 : index + %32 = arith.cmpi slt, %30, %c8 : index + %33 = arith.ori %31, %32 : i1 + scf.if %33 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x8xf32>) + } + %subview = memref.subview %reinterpret_cast_0[0, 0] [%29, %30] [1, 1] : memref<8x8xf32, strided<[?, 1], offset: ?>> to memref> + %subview_1 = memref.subview %alloc[0, 0] [%29, %30] [1, 1] : memref<8x8xf32> to memref> + memref.copy %subview, %subview_1 : memref> to memref> + %34 = bufferization.to_tensor %alloc restrict writable : memref<8x8xf32> + %35 = linalg.fill ins(%arg6 : f32) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + %36 = arith.mulf %12#0, %35 : tensor<8x8xf32> + %37 = linalg.fill ins(%arg7 : f32) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> + %38 = arith.mulf %34, %37 : tensor<8x8xf32> + %39 = arith.addf %36, %38 : tensor<8x8xf32> + %extracted_slice = tensor.extract_slice %39[0, 0] [%29, %30] [1, 1] : tensor<8x8xf32> to tensor + %subview_2 = memref.subview %reinterpret_cast[0, 0] [%29, %30] [1, 1] : memref<8x8xf32, strided<[?, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_2 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/addmm/triton-ascend/addmm_kernel.ttir b/python/test/ops/addmm/triton-ascend/addmm_kernel.ttir new file mode 100644 index 000000000..aec54a3d3 --- /dev/null +++ b/python/test/ops/addmm/triton-ascend/addmm_kernel.ttir @@ -0,0 +1,135 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0) +module { + tt.func public @addmm_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg3: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg4: f32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg5: f32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg6: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg7: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg8: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg9: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg10: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg11: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0), %arg12: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":17:0)) 
attributes {noinline = false} { + %c7_i32 = arith.constant 7 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst = arith.constant dense<8> : tensor<8x8xi32> loc(#loc1) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<8x8xf32> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.get_program_id y : i32 loc(#loc3) + %2 = arith.muli %0, %c8_i32 : i32 loc(#loc4) + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc5) + %4 = tt.splat %2 : i32 -> tensor<8xi32> loc(#loc6) + %5 = arith.addi %4, %3 : tensor<8xi32> loc(#loc6) + %6 = arith.muli %1, %c8_i32 : i32 loc(#loc7) + %7 = tt.splat %6 : i32 -> tensor<8xi32> loc(#loc8) + %8 = arith.addi %7, %3 : tensor<8xi32> loc(#loc8) + %9 = tt.expand_dims %5 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc9) + %10 = tt.splat %arg9 : i32 -> tensor<8x1xi32> loc(#loc10) + %11 = arith.muli %9, %10 : tensor<8x1xi32> loc(#loc10) + %12 = tt.expand_dims %3 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> loc(#loc11) + %13 = tt.broadcast %11 : tensor<8x1xi32> -> tensor<8x8xi32> loc(#loc12) + %14 = tt.broadcast %12 : tensor<1x8xi32> -> tensor<8x8xi32> loc(#loc12) + %15 = arith.addi %13, %14 : tensor<8x8xi32> loc(#loc12) + %16 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x!tt.ptr> loc(#loc13) + %17 = tt.addptr %16, %15 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> loc(#loc13) + %18 = tt.expand_dims %3 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc14) + %19 = tt.splat %arg10 : i32 -> tensor<8x1xi32> loc(#loc15) + %20 = arith.muli %18, %19 : tensor<8x1xi32> loc(#loc15) + %21 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> loc(#loc16) + %22 = tt.broadcast %20 : tensor<8x1xi32> -> tensor<8x8xi32> loc(#loc17) + %23 = tt.broadcast %21 : tensor<1x8xi32> -> tensor<8x8xi32> loc(#loc17) + %24 = arith.addi %22, %23 : tensor<8x8xi32> loc(#loc17) + %25 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x!tt.ptr> loc(#loc18) + %26 = tt.addptr %25, %24 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> loc(#loc18) + %27 = arith.addi %arg8, %c7_i32 : i32 loc(#loc49) + %28 = arith.divsi %27, %c8_i32 : i32 loc(#loc50) + %29 = arith.muli %arg10, %c8_i32 : i32 loc(#loc22) + %30 = tt.splat %29 : i32 -> tensor<8x8xi32> loc(#loc23) + %31:3 = scf.for %arg13 = %c0_i32 to %28 step %c1_i32 iter_args(%arg14 = %cst_0, %arg15 = %17, %arg16 = %26) -> (tensor<8x8xf32>, tensor<8x8x!tt.ptr>, tensor<8x8x!tt.ptr>) : i32 { + %57 = arith.muli %arg13, %c8_i32 : i32 loc(#loc25) + %58 = arith.subi %arg8, %57 : i32 loc(#loc26) + %59 = tt.splat %58 : i32 -> tensor<1x8xi32> loc(#loc27) + %60 = arith.cmpi slt, %12, %59 : tensor<1x8xi32> loc(#loc27) + %61 = tt.broadcast %60 : tensor<1x8xi1> -> tensor<8x8xi1> loc(#loc28) + %62 = tt.load %arg15, %61, %cst_0 : tensor<8x8x!tt.ptr> loc(#loc28) + %63 = tt.splat %58 : i32 -> tensor<8x1xi32> loc(#loc29) + %64 = arith.cmpi slt, %18, %63 : tensor<8x1xi32> loc(#loc29) + %65 = tt.broadcast %64 : tensor<8x1xi1> -> tensor<8x8xi1> loc(#loc30) + %66 = tt.load %arg16, %65, %cst_0 : tensor<8x8x!tt.ptr> loc(#loc30) + %67 = tt.dot %62, %66, %arg14 : tensor<8x8xf32> * tensor<8x8xf32> -> tensor<8x8xf32> loc(#loc31) + %68 = tt.addptr %arg15, %cst : tensor<8x8x!tt.ptr>, tensor<8x8xi32> loc(#loc32) + %69 = tt.addptr %arg16, %30 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> loc(#loc23) + scf.yield %67, %68, %69 : tensor<8x8xf32>, tensor<8x8x!tt.ptr>, tensor<8x8x!tt.ptr> loc(#loc33) + } loc(#loc24) + %32 = tt.splat %arg12 : i32 -> 
tensor<8x1xi32> loc(#loc34) + %33 = arith.muli %32, %9 : tensor<8x1xi32> loc(#loc34) + %34 = tt.splat %arg3 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc35) + %35 = tt.addptr %34, %33 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc35) + %36 = tt.broadcast %35 : tensor<8x1x!tt.ptr> -> tensor<8x8x!tt.ptr> loc(#loc36) + %37 = tt.addptr %36, %23 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> loc(#loc36) + %38 = tt.splat %arg6 : i32 -> tensor<8x1xi32> loc(#loc37) + %39 = arith.cmpi slt, %9, %38 : tensor<8x1xi32> loc(#loc37) + %40 = tt.splat %arg7 : i32 -> tensor<1x8xi32> loc(#loc38) + %41 = arith.cmpi slt, %21, %40 : tensor<1x8xi32> loc(#loc38) + %42 = tt.broadcast %39 : tensor<8x1xi1> -> tensor<8x8xi1> loc(#loc39) + %43 = tt.broadcast %41 : tensor<1x8xi1> -> tensor<8x8xi1> loc(#loc39) + %44 = arith.andi %42, %43 : tensor<8x8xi1> loc(#loc39) + %45 = tt.splat %arg11 : i32 -> tensor<8x1xi32> loc(#loc40) + %46 = arith.muli %45, %9 : tensor<8x1xi32> loc(#loc40) + %47 = tt.splat %arg2 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc41) + %48 = tt.addptr %47, %46 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc41) + %49 = tt.broadcast %48 : tensor<8x1x!tt.ptr> -> tensor<8x8x!tt.ptr> loc(#loc42) + %50 = tt.addptr %49, %23 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> loc(#loc42) + %51 = tt.load %50, %44, %cst_0 : tensor<8x8x!tt.ptr> loc(#loc43) + %52 = tt.splat %arg4 : f32 -> tensor<8x8xf32> loc(#loc44) + %53 = arith.mulf %31#0, %52 : tensor<8x8xf32> loc(#loc44) + %54 = tt.splat %arg5 : f32 -> tensor<8x8xf32> loc(#loc45) + %55 = arith.mulf %51, %54 : tensor<8x8xf32> loc(#loc45) + %56 = arith.addf %53, %55 : tensor<8x8xf32> loc(#loc46) + tt.store %37, %56, %44 : tensor<8x8x!tt.ptr> loc(#loc47) + tt.return loc(#loc48) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":39:26) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":40:26) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":42:22) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":42:50) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":42:37) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":43:22) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":43:37) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":45:30) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":45:41) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":45:60) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":45:53) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":45:22) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":46:29) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":46:40) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":46:60) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":46:52) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":46:22) +#loc19 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":49:33) +#loc21 = 
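(Editorial aside, not part of the patch.) Judging from the scaled add stored at the end of this addmm TTIR, the kernel computes acc * %arg4 + c * %arg5, i.e. the usual addmm formula; a minimal host-side sketch, using the hypothetical names alpha for %arg4 and beta for %arg5:

import torch

def addmm_reference(c, a, b, alpha=1.0, beta=1.0):
    # out = alpha * (a @ b) + beta * c, matching the scaled add stored by the kernel
    return alpha * (a @ b) + beta * c

a, b, c = torch.randn(8, 8), torch.randn(8, 8), torch.randn(8, 8)
print(torch.allclose(addmm_reference(c, a, b, alpha=0.5, beta=2.0),
                     torch.addmm(c, a, b, alpha=0.5, beta=2.0)))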
loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":62:33) +#loc23 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":62:18) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":49:22) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":52:44) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":52:40) +#loc27 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":52:36) +#loc28 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":51:12) +#loc29 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":57:36) +#loc30 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":56:12) +#loc31 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":60:33) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":61:18) +#loc33 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":62:8) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":66:33) +#loc35 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":66:21) +#loc36 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":66:52) +#loc37 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":67:33) +#loc38 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":67:58) +#loc39 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":67:39) +#loc40 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":68:33) +#loc41 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":68:21) +#loc42 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":68:52) +#loc43 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":69:19) +#loc44 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":71:32) +#loc45 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":71:48) +#loc46 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":71:40) +#loc47 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":73:21) +#loc48 = loc("/home/zhengyang/git/flagtree/python/test/ops/addmm/addmm_ascend.py":73:4) +#loc49 = loc(callsite(#loc19 at #loc20)) +#loc50 = loc(callsite(#loc21 at #loc20)) diff --git a/python/test/ops/amax/amax.py b/python/test/ops/amax/amax.py new file mode 100644 index 000000000..2feded917 --- /dev/null +++ b/python/test/ops/amax/amax.py @@ -0,0 +1,130 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def get_dtype_min(dtype): + """get a value which is less that all other values of that dtype""" + dtype_ = dtype.value # tl.dtype + if dtype_.is_floating(): + value: tl.constexpr = float("-inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = -1 * 2 ** (width 
- 1) + return value + if dtype_.is_int_unsigned(): + value: tl.constexpr = 0 + return value + + +@triton.jit +def amax_kernel( + inp, + out, + M, + N, + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 256, +): + dtype = inp.type.element_ty + min_value = get_dtype_min(dtype) + + # Map the program id to the row of inp it should compute. + pid = tl.program_id(0) + rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + inp = inp + rows * N + out = out + rows + row_mask = rows < M + + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + _all = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N)[None, :] + col_mask = cols < N + mask = row_mask and col_mask + a = tl.load(inp + cols, mask, other=min_value) + _all = tl.maximum(_all, a) + all = tl.max(_all, axis=1)[:, None] + tl.store(out, all, row_mask) + + +def amax(inp, dim=None, keepdim=False): + if dim is not None: + if isinstance(dim, int): + dim = [dim] + assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" + dtype = inp.dtype + + shape = list(inp.shape) + dim = [d % inp.ndim for d in dim] + N = 1 + for i in dim: + N *= shape[i] + shape[i] = 1 + M = inp.numel() // N + + out = torch.empty(shape, dtype=dtype, device=inp.device) + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + with torch_device_fn.device(inp.device): + amax_kernel[grid](inp, out, M, N) + if not keepdim: + out = out.squeeze(dim=dim) + return out + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + keepdim = True + dim = 1 + shape = (1, 32) + dtype = torch.float32 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.amax(ref_inp, dim=dim, keepdim=keepdim) + res_out = amax(inp, dim=dim, keepdim=keepdim) + check("value", ref_out, res_out) diff --git a/python/test/ops/amax/amax_ascend_perf.py b/python/test/ops/amax/amax_ascend_perf.py new file mode 100644 index 000000000..917a3566e --- /dev/null +++ b/python/test/ops/amax/amax_ascend_perf.py @@ -0,0 +1,137 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def get_dtype_min(dtype): + """get a value which is less that all other values of that dtype""" + dtype_ = dtype.value # tl.dtype + if dtype_.is_floating(): + value: tl.constexpr = float("-inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = -1 * 2 ** (width - 1) 
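(Editorial aside, not part of the patch.) The values get_dtype_min returns are the identity elements for a running maximum; a quick host-side sanity check against torch's dtype metadata:

import torch

print(float("-inf"))                                    # floating dtypes
print(-1 * 2 ** (8 - 1), torch.iinfo(torch.int8).min)   # signed ints: both -128
print(0, torch.iinfo(torch.uint8).min)                  # unsigned ints: both 0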
+ return value + if dtype_.is_int_unsigned(): + value: tl.constexpr = 0 + return value + + +@triton.jit +def amax_kernel( + inp, + out, + M, + N, + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 256, +): + dtype = inp.type.element_ty + min_value = get_dtype_min(dtype) + + # Map the program id to the row of inp it should compute. + workers = tl.num_programs(0) + pid = tl.program_id(0) + + total_workloads = tl.cdiv(M, BLOCK_M) + workloads = tl.cdiv(total_workloads, workers) + + for w in range(workloads): + work_id = pid + w * workers + rows = work_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + ninp = inp + rows * N + nout = out + rows + row_mask = rows < M + + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + _all = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N)[None, :] + col_mask = cols < N + mask = row_mask and col_mask + a = tl.load(ninp + cols, mask, other=min_value) + _all = tl.maximum(_all, a) + all = tl.max(_all, axis=1)[:, None] + tl.store(nout, all, row_mask) + + +def amax(inp, dim=None, keepdim=False): + if dim is not None: + if isinstance(dim, int): + dim = [dim] + assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" + dtype = inp.dtype + + shape = list(inp.shape) + dim = [d % inp.ndim for d in dim] + N = 1 + for i in dim: + N *= shape[i] + shape[i] = 1 + M = inp.numel() // N + + out = torch.empty(shape, dtype=dtype, device=inp.device) + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + with torch_device_fn.device(inp.device): + amax_kernel[grid](inp, out, M, N) + if not keepdim: + out = out.squeeze(dim=dim) + return out + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + keepdim = True + dim = 1 + shape = (1, 32) + dtype = torch.float32 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.amax(ref_inp, dim=dim, keepdim=keepdim) + res_out = amax(inp, dim=dim, keepdim=keepdim) + check("value", ref_out, res_out) diff --git a/python/test/ops/amax/triton-ascend-perf/amax_kernel.ttadapter b/python/test/ops/amax/triton-ascend-perf/amax_kernel.ttadapter new file mode 100644 index 000000000..9d2588dfc --- /dev/null +++ b/python/test/ops/amax/triton-ascend-perf/amax_kernel.ttadapter @@ -0,0 +1,69 @@ +module { + func.func @amax_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c256 = arith.constant 256 : index + %c1 = arith.constant 1 : index + %c8 = 
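(Editorial aside, not part of the patch.) The *_perf variant above differs from amax.py in how it distributes work: each program walks row-blocks with a stride of tl.num_programs(0). A small sketch of which row-blocks a given program visits, assuming the same BLOCK_M of 8:

import math

def blocks_for_program(pid, workers, M, BLOCK_M=8):
    total = math.ceil(M / BLOCK_M)            # total row-blocks
    per_program = math.ceil(total / workers)  # trip count of the kernel's `for w in range(workloads)` loop
    visited = [pid + w * workers for w in range(per_program)]
    # the kernel does not skip out-of-range blocks; it relies on the rows < M
    # mask instead, so only list the blocks that carry real rows here
    return [b for b in visited if b < total]

print(blocks_for_program(pid=1, workers=4, M=100))  # -> [1, 5, 9] out of 13 row-blocks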
arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<8x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x256xf32>) -> tensor<8x256xf32> + %2 = arith.divsi %arg5, %arg5 : i32 + scf.for %arg11 = %c0_i32 to %2 step %c1_i32 : i32 { + %3 = arith.muli %arg11, %arg5 : i32 + %4 = arith.addi %arg8, %3 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = arith.index_cast %arg4 : i32 to index + %8 = arith.muli %6, %7 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%6], sizes: [8, 1], strides: [1, 1] : memref to memref<8x1xf32, strided<[1, 1], offset: ?>> + %9 = scf.for %arg12 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg13 = %1) -> (tensor<8x256xf32>) : i32 { + %16 = arith.index_cast %arg12 : i32 to index + %17 = arith.addi %8, %16 : index + %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%17], sizes: [8, 256], strides: [%7, 1] : memref to memref<8x256xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x256xf32> + %18 = arith.addi %6, %c8 : index + %19 = arith.maxsi %6, %c1 : index + %20 = arith.minsi %18, %19 : index + %21 = arith.subi %20, %6 : index + %22 = arith.addi %16, %c256 : index + %23 = arith.maxsi %16, %7 : index + %24 = arith.minsi %22, %23 : index + %25 = arith.subi %24, %16 : index + %26 = arith.minsi %21, %c8 : index + %27 = arith.minsi %25, %c256 : index + %28 = arith.cmpi slt, %26, %c8 : index + %29 = arith.cmpi slt, %27, %c256 : index + %30 = arith.ori %28, %29 : i1 + scf.if %30 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x256xf32>) + } + %subview_1 = memref.subview %reinterpret_cast_0[0, 0] [%26, %27] [1, 1] : memref<8x256xf32, strided<[?, 1], offset: ?>> to memref> + %subview_2 = memref.subview %alloc[0, 0] [%26, %27] [1, 1] : memref<8x256xf32> to memref> + memref.copy %subview_1, %subview_2 : memref> to memref> + %31 = bufferization.to_tensor %alloc restrict writable : memref<8x256xf32> + %32 = arith.maxnumf %arg13, %31 : tensor<8x256xf32> + scf.yield %32 : tensor<8x256xf32> + } + %10 = tensor.empty() : tensor<8xf32> + %11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<8xf32>) -> tensor<8xf32> + %reduced = linalg.reduce ins(%9 : tensor<8x256xf32>) outs(%11 : tensor<8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %16 = arith.maxnumf %in, %init : f32 + linalg.yield %16 : f32 + } + %expanded = tensor.expand_shape %reduced [[0, 1]] output_shape [8, 1] : tensor<8xf32> into tensor<8x1xf32> + %12 = arith.addi %6, %c8 : index + %13 = arith.maxsi %6, %c1 : index + %14 = arith.minsi %12, %13 : index + %15 = arith.subi %14, %6 : index + %extracted_slice = tensor.extract_slice %expanded[0, 0] [%15, 1] [1, 1] : tensor<8x1xf32> to tensor + %subview = memref.subview %reinterpret_cast[0, 0] [%15, 1] [1, 1] : memref<8x1xf32, strided<[1, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + } + return + } +} + diff --git a/python/test/ops/amax/triton-ascend-perf/amax_kernel.ttir b/python/test/ops/amax/triton-ascend-perf/amax_kernel.ttir new file mode 100644 index 000000000..5e4b1ae3b --- /dev/null +++ b/python/test/ops/amax/triton-ascend-perf/amax_kernel.ttir @@ -0,0 +1,92 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":33:0) +#loc1 = loc(unknown) +#loc28 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":66:21) +#loc35 = loc(callsite(#loc1 at #loc28)) +module { + tt.func public @amax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":33:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":33:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":33:0)) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<8x256xf32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_0 = arith.constant dense<1> : tensor<8x1xi32> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_num_programs x : i32 loc(#loc2) + %1 = tt.get_program_id x : i32 loc(#loc3) + %2 = arith.divsi %0, %0 : i32 loc(#loc33) + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc6) + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc7) + %5 = tt.splat %arg2 : i32 -> tensor<8x1xi32> loc(#loc8) + %6 = tt.splat %arg0 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc9) + %7 = tt.splat %arg1 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc10) + %8 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc11) + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc12) + %10 = tt.splat %arg2 : i32 -> tensor<1x256xi32> loc(#loc13) + scf.for %arg3 = %c0_i32 to %2 step %c1_i32 : i32 { + %11 = arith.muli %arg3, %0 : i32 loc(#loc15) + %12 = arith.addi %1, %11 : i32 loc(#loc16) + %13 = arith.muli %12, %c8_i32 : i32 loc(#loc17) + %14 = tt.splat %13 : i32 -> tensor<8x1xi32> loc(#loc18) + %15 = arith.addi %14, %4 : tensor<8x1xi32> loc(#loc18) + %16 = arith.muli %15, %5 : tensor<8x1xi32> loc(#loc8) + %17 = tt.addptr %6, %16 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc9) + %18 = tt.addptr %7, %15 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc10) + %19 = arith.cmpi slt, %15, %cst_0 : tensor<8x1xi32> loc(#loc19) + %20 = tt.broadcast %19 : tensor<8x1xi1> -> tensor<8x256xi1> loc(#loc20) + %21 = tt.broadcast %17 : tensor<8x1x!tt.ptr> -> tensor<8x256x!tt.ptr> loc(#loc21) + %22 = scf.for %arg4 = %c0_i32 to %arg2 step %c256_i32 iter_args(%arg5 = %cst) -> (tensor<8x256xf32>) : i32 { + %25 = tt.splat %arg4 : i32 -> tensor<1x256xi32> loc(#loc23) + %26 = arith.addi %25, %9 : tensor<1x256xi32> loc(#loc23) + %27 = arith.cmpi slt, %26, %10 : tensor<1x256xi32> loc(#loc13) + %28 = tt.broadcast %27 : tensor<1x256xi1> -> tensor<8x256xi1> loc(#loc20) + %29 = arith.andi %20, %28 : tensor<8x256xi1> loc(#loc20) + %30 = tt.broadcast %26 : tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc21) + %31 = tt.addptr %21, %30 : tensor<8x256x!tt.ptr>, tensor<8x256xi32> loc(#loc21) + %32 = tt.load %31, %29, %cst : tensor<8x256x!tt.ptr> loc(#loc24) + %33 = arith.maxnumf %arg5, %32 : tensor<8x256xf32> loc(#loc25) + scf.yield %33 : tensor<8x256xf32> loc(#loc26) + } loc(#loc22) + %23 = "tt.reduce"(%22) <{axis = 1 : i32}> ({ + ^bb0(%arg4: f32 loc(callsite(#loc1 at #loc28)), %arg5: f32 loc(callsite(#loc1 at #loc28))): + %25 = arith.maxnumf %arg4, %arg5 : f32 loc(#loc37) + tt.reduce.return %25 : f32 loc(#loc34) + }) : (tensor<8x256xf32>) -> tensor<8xf32> loc(#loc34) + %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<8xf32> -> tensor<8x1xf32> loc(#loc30) + 
tt.store %18, %24, %19 : tensor<8x1x!tt.ptr> loc(#loc31) + } loc(#loc14) + tt.return loc(#loc32) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":45:30) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":46:24) +#loc4 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":49:41) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":53:48) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":53:57) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":54:28) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":54:21) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":55:21) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":61:38) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":61:47) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":62:30) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":51:19) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":52:28) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":52:24) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":53:25) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":53:35) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":56:26) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":63:32) +#loc21 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":64:31) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":60:31) +#loc23 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":61:25) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":64:37) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":65:36) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":65:12) +#loc27 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":184:40) +#loc29 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":163:27) +#loc30 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":66:35) +#loc31 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":67:28) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax_ascend_perf.py":51:4) +#loc33 = loc(callsite(#loc4 at #loc5)) +#loc34 = loc(callsite(#loc27 at #loc28)) +#loc36 = loc(callsite(#loc29 at #loc27)) +#loc37 = loc(callsite(#loc36 at #loc28)) diff --git a/python/test/ops/amax/triton-ascend/amax_kernel.ttadapter b/python/test/ops/amax/triton-ascend/amax_kernel.ttadapter new file mode 100644 index 000000000..a7d35893b --- /dev/null +++ b/python/test/ops/amax/triton-ascend/amax_kernel.ttadapter @@ -0,0 +1,63 @@ +module { + func.func @amax_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : 
i32, tt.tensor_kind = 1 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c256 = arith.constant 256 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<8x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x256xf32>) -> tensor<8x256xf32> + %2 = arith.muli %arg8, %c8_i32 : i32 + %3 = arith.index_cast %2 : i32 to index + %4 = arith.index_cast %arg4 : i32 to index + %5 = arith.muli %3, %4 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%3], sizes: [8, 1], strides: [1, 1] : memref to memref<8x1xf32, strided<[1, 1], offset: ?>> + %6 = scf.for %arg11 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg12 = %1) -> (tensor<8x256xf32>) : i32 { + %13 = arith.index_cast %arg11 : i32 to index + %14 = arith.addi %5, %13 : index + %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%14], sizes: [8, 256], strides: [%4, 1] : memref to memref<8x256xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x256xf32> + %15 = arith.addi %3, %c8 : index + %16 = arith.maxsi %3, %c1 : index + %17 = arith.minsi %15, %16 : index + %18 = arith.subi %17, %3 : index + %19 = arith.addi %13, %c256 : index + %20 = arith.maxsi %13, %4 : index + %21 = arith.minsi %19, %20 : index + %22 = arith.subi %21, %13 : index + %23 = arith.minsi %18, %c8 : index + %24 = arith.minsi %22, %c256 : index + %25 = arith.cmpi slt, %23, %c8 : index + %26 = arith.cmpi slt, %24, %c256 : index + %27 = arith.ori %25, %26 : i1 + scf.if %27 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x256xf32>) + } + %subview_1 = memref.subview %reinterpret_cast_0[0, 0] [%23, %24] [1, 1] : memref<8x256xf32, strided<[?, 1], offset: ?>> to memref> + %subview_2 = memref.subview %alloc[0, 0] [%23, %24] [1, 1] : memref<8x256xf32> to memref> + memref.copy %subview_1, %subview_2 : memref> to memref> + %28 = bufferization.to_tensor %alloc restrict writable : memref<8x256xf32> + %29 = arith.maxnumf %arg12, %28 : tensor<8x256xf32> + scf.yield %29 : tensor<8x256xf32> + } + %7 = tensor.empty() : tensor<8xf32> + %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8xf32>) -> tensor<8xf32> + %reduced = linalg.reduce ins(%6 : tensor<8x256xf32>) outs(%8 : tensor<8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %13 = arith.maxnumf %in, %init : f32 + linalg.yield %13 : f32 + } + %expanded = tensor.expand_shape %reduced [[0, 1]] output_shape [8, 1] : tensor<8xf32> into tensor<8x1xf32> + %9 = arith.addi %3, %c8 : index + %10 = arith.maxsi %3, %c1 : index + %11 = arith.minsi %9, %10 : index + %12 = arith.subi %11, %3 : index + %extracted_slice = tensor.extract_slice %expanded[0, 0] [%12, 1] [1, 1] : tensor<8x1xf32> to tensor + %subview = memref.subview %reinterpret_cast[0, 0] [%12, 1] [1, 1] : memref<8x1xf32, strided<[1, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/amax/triton-ascend/amax_kernel.ttir b/python/test/ops/amax/triton-ascend/amax_kernel.ttir new file mode 100644 index 000000000..18dbc06a3 --- /dev/null +++ b/python/test/ops/amax/triton-ascend/amax_kernel.ttir @@ -0,0 +1,78 @@ +#loc = 
loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":33:0) +#loc1 = loc(unknown) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":59:17) +#loc28 = loc(callsite(#loc1 at #loc22)) +module { + tt.func public @amax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":33:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":33:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":33:0)) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<8x256xf32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_0 = arith.constant dense<1> : tensor<8x1xi32> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c8_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4) + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc5) + %4 = tt.splat %1 : i32 -> tensor<8x1xi32> loc(#loc6) + %5 = arith.addi %4, %3 : tensor<8x1xi32> loc(#loc6) + %6 = tt.splat %arg2 : i32 -> tensor<8x1xi32> loc(#loc7) + %7 = arith.muli %5, %6 : tensor<8x1xi32> loc(#loc7) + %8 = tt.splat %arg0 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc8) + %9 = tt.addptr %8, %7 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc8) + %10 = tt.splat %arg1 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %5 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc9) + %12 = arith.cmpi slt, %5, %cst_0 : tensor<8x1xi32> loc(#loc10) + %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc11) + %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc12) + %15 = tt.splat %arg2 : i32 -> tensor<1x256xi32> loc(#loc13) + %16 = tt.broadcast %12 : tensor<8x1xi1> -> tensor<8x256xi1> loc(#loc14) + %17 = tt.broadcast %9 : tensor<8x1x!tt.ptr> -> tensor<8x256x!tt.ptr> loc(#loc15) + %18 = scf.for %arg3 = %c0_i32 to %arg2 step %c256_i32 iter_args(%arg4 = %cst) -> (tensor<8x256xf32>) : i32 { + %21 = tt.splat %arg3 : i32 -> tensor<1x256xi32> loc(#loc17) + %22 = arith.addi %21, %14 : tensor<1x256xi32> loc(#loc17) + %23 = arith.cmpi slt, %22, %15 : tensor<1x256xi32> loc(#loc13) + %24 = tt.broadcast %23 : tensor<1x256xi1> -> tensor<8x256xi1> loc(#loc14) + %25 = arith.andi %16, %24 : tensor<8x256xi1> loc(#loc14) + %26 = tt.broadcast %22 : tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc15) + %27 = tt.addptr %17, %26 : tensor<8x256x!tt.ptr>, tensor<8x256xi32> loc(#loc15) + %28 = tt.load %27, %25, %cst : tensor<8x256x!tt.ptr> loc(#loc18) + %29 = arith.maxnumf %arg4, %28 : tensor<8x256xf32> loc(#loc19) + scf.yield %29 : tensor<8x256xf32> loc(#loc20) + } loc(#loc16) + %19 = "tt.reduce"(%18) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc1 at #loc22)), %arg4: f32 loc(callsite(#loc1 at #loc22))): + %21 = arith.maxnumf %arg3, %arg4 : f32 loc(#loc30) + tt.reduce.return %21 : f32 loc(#loc27) + }) : (tensor<8x256xf32>) -> tensor<8xf32> loc(#loc27) + %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<8xf32> -> tensor<8x1xf32> loc(#loc24) + tt.store %11, %20, %12 : tensor<8x1x!tt.ptr> loc(#loc25) + tt.return loc(#loc26) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":45:24) +#loc3 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":46:17) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":46:40) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":46:49) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":46:27) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":47:23) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":47:16) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":48:16) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":49:22) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":54:34) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":54:43) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":55:26) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":56:28) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":57:26) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":53:27) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":54:21) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":57:32) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":58:32) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":58:8) +#loc21 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":184:40) +#loc23 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":163:27) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":59:31) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":60:23) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/amax/amax.py":60:4) +#loc27 = loc(callsite(#loc21 at #loc22)) +#loc29 = loc(callsite(#loc23 at #loc21)) +#loc30 = loc(callsite(#loc29 at #loc22)) diff --git a/python/test/ops/angle/angle.py b/python/test/ops/angle/angle.py new file mode 100644 index 000000000..b9adbe948 --- /dev/null +++ b/python/test/ops/angle/angle.py @@ -0,0 +1,91 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def angle_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + real = tl.load(x_ptr + offsets, mask=mask) + imag = tl.load(y_ptr + offsets, mask=mask) + real_last, imag_last = ( + (real.to(tl.float32), imag.to(tl.float32)) + if real.dtype == tl.float16 + else (real, imag) + ) + output = tl.math.atan2(imag_last, real_last) + tl.store(output_ptr + offsets, output, mask=mask) + + +def angle(input_tensor: torch.Tensor) -> torch.Tensor: + if input_tensor.dtype == torch.complex32 or input_tensor.dtype == torch.complex64: + x = input_tensor.real + y = input_tensor.imag + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + with 
torch_device_fn.device(x.device): + angle_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (4, 13) + dtype = torch.complex32 + + # inp + inp = torch.randn(shape, dtype=dtype, device="cpu").to(device) + ref_inp = inp.cpu() + + # op + #ref_out = torch.angle(ref_inp) + res_out = angle(inp) + #check("value", ref_out, res_out) diff --git a/python/test/ops/angle/triton-ascend/angle_kernel.ttadapter b/python/test/ops/angle/triton-ascend/angle_kernel.ttadapter new file mode 100644 index 000000000..2476c3b57 --- /dev/null +++ b/python/test/ops/angle/triton-ascend/angle_kernel.ttadapter @@ -0,0 +1,72 @@ +module { + func.func private @__hmf_atanf(f32) -> f32 attributes {llvm.readnone} + func.func @angle_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.tensor_kind = 0 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %cst = arith.constant 0.000000e+00 : f16 + %c1024 = arith.constant 1024 : index + %c1024_i32 = arith.constant 1024 : i32 + %cst_0 = arith.constant -3.14159274 : f32 + %cst_1 = arith.constant 3.14159274 : f32 + %cst_2 = arith.constant -1.57079637 : f32 + %cst_3 = arith.constant 1.57079637 : f32 + %cst_4 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<1024xf32> + %1 = linalg.fill ins(%cst_4 : f32) outs(%0 : tensor<1024xf32>) -> tensor<1024xf32> + %2 = linalg.fill ins(%cst_3 : f32) outs(%0 : tensor<1024xf32>) -> tensor<1024xf32> + %3 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<1024xf32>) -> tensor<1024xf32> + %4 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<1024xf32>) -> tensor<1024xf32> + %5 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<1024xf32>) -> tensor<1024xf32> + %6 = arith.muli %arg9, %c1024_i32 : i32 + %7 = arith.index_cast %6 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%7], sizes: [1024], strides: [1] : memref to memref<1024xf16, strided<[1], offset: ?>> + %alloc = memref.alloc() : memref<1024xf16> + %8 = arith.addi %7, %c1024 : index + %9 = arith.index_cast %arg5 : i32 to index + %10 = arith.maxsi %7, %9 : index + %11 = arith.minsi %8, %10 : index + %12 = arith.subi %11, %7 : index + %13 = arith.cmpi slt, %12, %c1024 : index + scf.if %13 { + linalg.fill ins(%cst : f16) outs(%alloc : memref<1024xf16>) + } + %subview = memref.subview %reinterpret_cast[0] [%12] [1] : memref<1024xf16, strided<[1], offset: ?>> to memref> + %subview_5 = memref.subview %alloc[0] [%12] [1] : memref<1024xf16> to memref> + memref.copy %subview, %subview_5 : memref> 
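(Editorial aside, not part of the patch.) angle_kernel computes atan2(imag, real) after upcasting f16 inputs to f32, which matches torch.angle on complex tensors; a CPU-side check along the lines of the comparison that is commented out in angle.py:

import torch

inp = torch.randn(4, 13, dtype=torch.complex64)
ref = torch.angle(inp)                 # reference
res = torch.atan2(inp.imag, inp.real)  # what the kernel computes elementwise
print(torch.allclose(ref, res))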
to memref> + %14 = bufferization.to_tensor %alloc restrict writable : memref<1024xf16> + %reinterpret_cast_6 = memref.reinterpret_cast %arg3 to offset: [%7], sizes: [1024], strides: [1] : memref to memref<1024xf16, strided<[1], offset: ?>> + %alloc_7 = memref.alloc() : memref<1024xf16> + scf.if %13 { + linalg.fill ins(%cst : f16) outs(%alloc_7 : memref<1024xf16>) + } + %subview_8 = memref.subview %reinterpret_cast_6[0] [%12] [1] : memref<1024xf16, strided<[1], offset: ?>> to memref> + %subview_9 = memref.subview %alloc_7[0] [%12] [1] : memref<1024xf16> to memref> + memref.copy %subview_8, %subview_9 : memref> to memref> + %15 = bufferization.to_tensor %alloc_7 restrict writable : memref<1024xf16> + %16 = arith.extf %14 : tensor<1024xf16> to tensor<1024xf32> + %17 = arith.extf %15 : tensor<1024xf16> to tensor<1024xf32> + %18 = arith.cmpf oeq, %16, %1 : tensor<1024xf32> + %19 = arith.divf %17, %16 : tensor<1024xf32> + %mapped = linalg.map { func.call {callee = @__hmf_atanf} } ins(%19 : tensor<1024xf32>) outs(%19 : tensor<1024xf32>) + %20 = arith.select %18, %1, %mapped : tensor<1024xi1>, tensor<1024xf32> + %21 = arith.cmpf ogt, %17, %1 : tensor<1024xf32> + %22 = arith.andi %18, %21 : tensor<1024xi1> + %23 = arith.select %22, %2, %20 : tensor<1024xi1>, tensor<1024xf32> + %24 = arith.cmpf olt, %17, %1 : tensor<1024xf32> + %25 = arith.andi %18, %24 : tensor<1024xi1> + %26 = arith.select %25, %3, %23 : tensor<1024xi1>, tensor<1024xf32> + %27 = arith.cmpf olt, %16, %1 : tensor<1024xf32> + %28 = arith.cmpf oge, %17, %1 : tensor<1024xf32> + %29 = arith.andi %27, %28 : tensor<1024xi1> + %30 = arith.select %29, %4, %1 : tensor<1024xi1>, tensor<1024xf32> + %31 = arith.andi %27, %24 : tensor<1024xi1> + %32 = arith.select %31, %5, %1 : tensor<1024xi1>, tensor<1024xf32> + %33 = arith.addf %26, %30 : tensor<1024xf32> + %34 = arith.addf %33, %32 : tensor<1024xf32> + %reinterpret_cast_10 = memref.reinterpret_cast %arg4 to offset: [%7], sizes: [1024], strides: [1] : memref to memref<1024xf16, strided<[1], offset: ?>> + %35 = arith.truncf %34 : tensor<1024xf32> to tensor<1024xf16> + %extracted_slice = tensor.extract_slice %35[0] [%12] [1] : tensor<1024xf16> to tensor + %subview_11 = memref.subview %reinterpret_cast_10[0] [%12] [1] : memref<1024xf16, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_11 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/angle/triton-ascend/angle_kernel.ttir b/python/test/ops/angle/triton-ascend/angle_kernel.ttir new file mode 100644 index 000000000..c1a8c23e5 --- /dev/null +++ b/python/test/ops/angle/triton-ascend/angle_kernel.ttir @@ -0,0 +1,102 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":17:0) +module { + tt.func public @angle_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":17:0), %arg1: !tt.ptr loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":17:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":17:0), %arg3: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":17:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1) + %cst_0 = arith.constant dense<1.57079637> : tensor<1024xf32> loc(#loc1) + %cst_1 = arith.constant dense<-1.57079637> : tensor<1024xf32> loc(#loc1) + %cst_2 = arith.constant dense<3.14159274> : tensor<1024xf32> loc(#loc1) + 
%cst_3 = arith.constant dense<-3.14159274> : tensor<1024xf32> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<1024xf16> loc(#loc1) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5) + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6) + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6) + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc7) + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc7) + %9 = tt.load %8, %6, %cst_4 : tensor<1024x!tt.ptr> loc(#loc8) + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc9) + %12 = tt.load %11, %6, %cst_4 : tensor<1024x!tt.ptr> loc(#loc10) + %13 = arith.extf %9 : tensor<1024xf16> to tensor<1024xf32> loc(#loc11) + %14 = arith.extf %12 : tensor<1024xf16> to tensor<1024xf32> loc(#loc12) + %15 = arith.cmpf oeq, %13, %cst : tensor<1024xf32> loc(#loc35) + %16 = arith.divf %14, %13 : tensor<1024xf32> loc(#loc36) + %17 = tt.extern_elementwise %16 {libname = "", libpath = "", pure = true, symbol = "__hmf_atanf"} : (tensor<1024xf32>) -> tensor<1024xf32> loc(#loc37) + %18 = arith.select %15, %cst, %17 : tensor<1024xi1>, tensor<1024xf32> loc(#loc38) + %19 = arith.cmpf ogt, %14, %cst : tensor<1024xf32> loc(#loc39) + %20 = arith.andi %15, %19 : tensor<1024xi1> loc(#loc40) + %21 = arith.select %20, %cst_0, %18 : tensor<1024xi1>, tensor<1024xf32> loc(#loc41) + %22 = arith.cmpf olt, %14, %cst : tensor<1024xf32> loc(#loc42) + %23 = arith.andi %15, %22 : tensor<1024xi1> loc(#loc43) + %24 = arith.select %23, %cst_1, %21 : tensor<1024xi1>, tensor<1024xf32> loc(#loc44) + %25 = arith.cmpf olt, %13, %cst : tensor<1024xf32> loc(#loc45) + %26 = arith.cmpf oge, %14, %cst : tensor<1024xf32> loc(#loc46) + %27 = arith.andi %25, %26 : tensor<1024xi1> loc(#loc47) + %28 = arith.select %27, %cst_2, %cst : tensor<1024xi1>, tensor<1024xf32> loc(#loc48) + %29 = arith.andi %25, %22 : tensor<1024xi1> loc(#loc49) + %30 = arith.select %29, %cst_3, %cst : tensor<1024xi1>, tensor<1024xf32> loc(#loc50) + %31 = arith.addf %24, %28 : tensor<1024xf32> loc(#loc51) + %32 = arith.addf %31, %30 : tensor<1024xf32> loc(#loc52) + %33 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> loc(#loc32) + %34 = tt.addptr %33, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> loc(#loc32) + %35 = arith.truncf %32 : tensor<1024xf32> to tensor<1024xf16> loc(#loc33) + tt.store %34, %35, %6 : tensor<1024x!tt.ptr> loc(#loc33) + tt.return loc(#loc34) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":24:24) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":25:24) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":26:41) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":26:28) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":27:21) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":28:27) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":28:19) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":29:27) +#loc10 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":29:19) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":31:17) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":31:38) +#loc13 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":131:27) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":35:38) +#loc15 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":131:61) +#loc16 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":131:45) +#loc17 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":131:35) +#loc18 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":132:38) +#loc19 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":132:34) +#loc20 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":132:51) +#loc21 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":133:38) +#loc22 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":133:34) +#loc23 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":133:52) +#loc24 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":135:29) +#loc25 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":135:40) +#loc26 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":135:35) +#loc27 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":135:48) +#loc28 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":136:35) +#loc29 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":136:48) +#loc30 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":137:19) +#loc31 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/triton_patch/language/standard.py":137:28) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":36:26) +#loc33 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":36:35) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/angle/angle.py":36:4) +#loc35 = loc(callsite(#loc13 at #loc14)) +#loc36 = loc(callsite(#loc15 at #loc14)) +#loc37 = loc(callsite(#loc16 at #loc14)) +#loc38 = loc(callsite(#loc17 at #loc14)) +#loc39 = loc(callsite(#loc18 at #loc14)) +#loc40 = loc(callsite(#loc19 at #loc14)) +#loc41 = loc(callsite(#loc20 at #loc14)) +#loc42 = loc(callsite(#loc21 at #loc14)) +#loc43 = loc(callsite(#loc22 at #loc14)) +#loc44 = loc(callsite(#loc23 at #loc14)) +#loc45 = loc(callsite(#loc24 at #loc14)) +#loc46 = loc(callsite(#loc25 at #loc14)) +#loc47 = loc(callsite(#loc26 at #loc14)) +#loc48 = loc(callsite(#loc27 at #loc14)) +#loc49 = loc(callsite(#loc28 at #loc14)) +#loc50 = loc(callsite(#loc29 at #loc14)) +#loc51 = loc(callsite(#loc30 at #loc14)) +#loc52 = loc(callsite(#loc31 at #loc14)) diff --git a/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb.py 
b/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb.py new file mode 100644 index 000000000..250e74629 --- /dev/null +++ b/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb.py @@ -0,0 +1,332 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def apply_rotary_pos_emb_kernel( + oq_ptr, + ok_ptr, + q_ptr, # (n_tokens, q_heads, head_dim) + k_ptr, # (n_tokens, k_heads, head_dim) + cos_ptr, # (max_seq_len, dim // 2) + sin_ptr, # (max_seq_len, dim // 2) + pos_ptr, # (n_tokens, ) + q_stride_s, + q_stride_h, + q_stride_d, + k_stride_s, + k_stride_h, + k_stride_d, + oq_stride_s, + oq_stride_h, + oq_stride_d, + ok_stride_s, + ok_stride_h, + ok_stride_d, + p_stride_s, + cos_stride_s, + sin_stride_s, + seq_len, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + PADDED_HEAD_DIM: tl.constexpr, + ROTARY_INTERLEAVED: tl.constexpr, + MAX_POSITION_EMBEDDINGS: tl.constexpr, +): + s_id = tl.program_id(0) + + if pos_ptr is None: + pos_id = s_id % seq_len + else: + pos_ptr += s_id * p_stride_s + pos_id = tl.load(pos_ptr) + cos_ptr += pos_id * cos_stride_s + sin_ptr += pos_id * sin_stride_s + + # note: set TRITON_DEBUG=1 to enable this check + tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, "position id out of bound") + + ordered_block = tl.arange(0, PADDED_HEAD_DIM) + mask = ordered_block < HEAD_DIM + if ROTARY_INTERLEAVED: + odd_mask = ordered_block % 2 == 0 + rotated_block = tl.where(odd_mask, ordered_block + 1, ordered_block - 1) + sin_cos_block = ordered_block // 2 + cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.where(odd_mask, -sin, sin) + else: + rotated_block = (ordered_block + HEAD_DIM // 2) % HEAD_DIM + sin_cos_block = ordered_block % (HEAD_DIM // 2) + cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.where(rotated_block < HEAD_DIM // 2, sin, -sin) + + oq_ptr += s_id * oq_stride_s + q_ptr += s_id * q_stride_s + + for off_h in range(0, NUM_Q_HEADS): + ordered_cols = off_h * q_stride_h + (ordered_block * q_stride_d) + rotated_cols = off_h * q_stride_h + (rotated_block * q_stride_d) + output_offs = off_h * oq_stride_h + (ordered_block * oq_stride_d) + + q = tl.load(q_ptr + ordered_cols, mask=mask, other=0.0) + rotated_q = tl.load(q_ptr + rotated_cols, mask=mask, other=0.0) + y = q * cos + rotated_q * sin + tl.store(oq_ptr + output_offs, y, mask=mask) + + ok_ptr += s_id * ok_stride_s + k_ptr += s_id * k_stride_s + + for off_h in range(0, NUM_K_HEADS): + ordered_cols = off_h * k_stride_h + (ordered_block * k_stride_d) + rotated_cols = off_h * k_stride_h + (rotated_block * k_stride_d) + output_offs = off_h * ok_stride_h + (ordered_block * ok_stride_d) + + k = tl.load(k_ptr + ordered_cols, mask=mask, other=0.0) + rotated_k = tl.load(k_ptr + rotated_cols, mask=mask, other=0.0) + y = k * cos + rotated_k * sin + tl.store(ok_ptr + output_offs, y, mask=mask) + + +def apply_rotary_pos_emb( + q, + k, + cos, + sin, + position_ids = None, + 
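(Editorial aside, not part of the patch.) The index bookkeeping in apply_rotary_pos_emb_kernel is easiest to see on a concrete size; reproducing it on the host for HEAD_DIM = 8:

import torch

HEAD_DIM = 8
ordered = torch.arange(HEAD_DIM)

# interleaved layout: each adjacent pair (0,1), (2,3), ... is rotated in place
odd_mask = ordered % 2 == 0
rotated_interleaved = torch.where(odd_mask, ordered + 1, ordered - 1)
print(rotated_interleaved.tolist())          # [1, 0, 3, 2, 5, 4, 7, 6]
print((ordered // 2).tolist())               # cos/sin index: [0, 0, 1, 1, 2, 2, 3, 3]

# half-rotation layout: the first half of the head pairs with the second half
rotated_half = (ordered + HEAD_DIM // 2) % HEAD_DIM
print(rotated_half.tolist())                 # [4, 5, 6, 7, 0, 1, 2, 3]
print((ordered % (HEAD_DIM // 2)).tolist())  # cos/sin index: [0, 1, 2, 3, 0, 1, 2, 3]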
rotary_interleaved: bool = False, + inplace: bool = False, +): + """ + Apply rotary position embedding to q and k + + Args: + q: (*, q_heads, head_dim) + k: (*, k_heads, head_dim) + cos: (max_seq_len, head_dim // 2) + sin: (max_seq_len, head_dim // 2) + position_ids: (*, ), optional, position ids for each token + rotary_interleaved: whether the head_dim is rotated in an interleaved way + + Returns: + q_embed: (*, q_heads, head_dim) + k_embed: (*, k_heads, head_dim) + """ + assert ( + k.shape[-1] == q.shape[-1] + ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}" + assert ( + cos.shape[-1] == sin.shape[-1] + ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}" + assert ( + cos.shape[-1] * 2 == q.shape[-1] + ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}" + assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension" + assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension" + + q_shape = q.shape + k_shape = k.shape + + assert ( + q.shape[:-2] == k.shape[:-2] + ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}" + if position_ids is None: + assert ( + len(q.shape) == 4 + ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}" + seq_len = q.shape[-3] + else: + assert ( + position_ids.shape == q.shape[:-2] + ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}" + + position_ids = position_ids.view(-1) + seq_len = None + + q = q.view(-1, q.shape[-2], q.shape[-1]) + k = k.view(-1, k.shape[-2], k.shape[-1]) + + n_tokens, q_heads, head_dim = q.shape + + # The block size must be the next power of two, sometimes we need to pad it. + padded_head_dim = max(triton.next_power_of_2(head_dim), 16) + + if not inplace: + q_embed = torch.empty_like(q) + k_embed = torch.empty_like(k) + + grid = (n_tokens,) + with torch_device_fn.device(q_embed.device): + apply_rotary_pos_emb_kernel[grid]( + q_embed, + k_embed, + q, + k, + cos, + sin, + position_ids, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + q_embed.stride(0), + q_embed.stride(1), + q_embed.stride(2), + k_embed.stride(0), + k_embed.stride(1), + k_embed.stride(2), + position_ids.stride(0) if position_ids is not None else 0, + cos.stride(0), + sin.stride(0), + seq_len, + q.shape[-2], + k.shape[-2], + head_dim, + padded_head_dim, + rotary_interleaved, + MAX_POSITION_EMBEDDINGS=cos.shape[0], + ) + q_embed = q_embed.view(q_shape) + k_embed = k_embed.view(k_shape) + return q_embed, k_embed + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +# for test +def rotate_interleave(x): + """Rotates interleave the hidden dims of the input.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def torch_apply_rotary_pos_emb( 
+ q, + k, + cos, + sin, + position_ids = None, + rotary_interleaved: bool = False, +): + q = q.float() + k = k.float() + if position_ids is None: + cos = cos[None, : q.size(-3), None, :] + sin = sin[None, : q.size(-3), None, :] + else: + cos = cos[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + sin = sin[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + if rotary_interleaved: + cos = torch.repeat_interleave(cos, 2, dim=-1) # [bs, seq_len, 1, dim] + sin = torch.repeat_interleave(sin, 2, dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_interleave + else: + cos = torch.cat([cos, cos], dim=-1) # [bs, seq_len, 1, dim] + sin = torch.cat([sin, sin], dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_half + + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) + + return q_embed, k_embed + + +def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device=device): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +if __name__ == "__main__": + # param + batch_size = 2 + max_seq_len = 16 + q_heads = 6 + k_heads = 2 + head_dim = 8 + rotary_interleaved = True + has_pos_id = True + dtype = torch.float32 + seq_len = 12 + + # inp + q = torch.randn( + (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device=device + ) + k = torch.randn( + (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device=device + ) + position_ids = torch.randint( + 0, max_seq_len, (batch_size, seq_len), device=device + ) + cos, sin = get_rope_cos_sin(max_seq_len, head_dim, dtype, device=device) + ref_q = q.cpu() + ref_k = k.cpu() + ref_cos = cos.cpu() + ref_sin = sin.cpu() + ref_position_ids = position_ids.cpu() + + # op + q_embed_ref, k_embed_ref = torch_apply_rotary_pos_emb( + q=ref_q, + k=ref_k, + cos=ref_cos, + sin=ref_sin, + position_ids=ref_position_ids if has_pos_id else None, + rotary_interleaved=rotary_interleaved, + ) + q_embed_out, k_embed_out = apply_rotary_pos_emb( + q=q, + k=k, + cos=cos, + sin=sin, + position_ids=position_ids if has_pos_id else None, + rotary_interleaved=rotary_interleaved, + ) + check("q_embed", q_embed_ref, q_embed_out) + check("k_embed", k_embed_ref, k_embed_out) diff --git a/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py b/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py new file mode 100644 index 000000000..81805eaf7 --- /dev/null +++ b/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py @@ -0,0 +1,406 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def rotary_embedding_rw_kernel( + state_out, + state, + cos, + sin, + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + num_tokens, + num_heads, + token_range, + head_range, + dim_range_x, + dim_range_y, + rotary_interleaved: tl.constexpr, +): + state_x_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_x[None, None, :] * 
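(Editorial note, not part of the patch.) torch_apply_rotary_pos_emb selects rotate_fn = rotate_half on the non-interleaved path, but no rotate_half definition appears in this hunk (the bundled test only exercises rotary_interleaved=True, so the missing name is never hit). The conventional definition would look like the following hypothetical helper:

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)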
stride_state_d + ) + state_y_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_y[None, None, :] * stride_state_d + ) + + cos_sim_offset = ( + token_range[:, None, None] * stride_cos_n + + dim_range_x[None, None, :] * stride_cos_d + ) + if rotary_interleaved: + sin_sim_offset = ( + token_range[:, None, None] * stride_cos_n + + dim_range_y[None, None, :] * stride_cos_d + ) + else: + sin_sim_offset = cos_sim_offset + + state_x = tl.load( + state + state_x_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + state_y = tl.load( + state + state_y_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + + cos_loaded = tl.load( + cos + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ).to(tl.float32) + sin_loaded = tl.load( + sin + sin_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ).to(tl.float32) + + out_x = state_x * cos_loaded - state_y * sin_loaded + out_y = state_x * sin_loaded + state_y * cos_loaded + + tl.store( + state_out + state_x_offset, + out_x, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + tl.store( + state_out + state_y_offset, + out_y, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + + +@triton.jit +def rotary_embedding_siso_kernel( + state_out, # [num_tokens, head_num, head_dim] + state, # [num_tokens, head_num, head_dim] + cos, # [num_tokens, 1, head_dim // 2] + sin, # [num_tokens, 1, head_dim // 2] + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + num_tokens, + num_heads, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, + rotary_interleaved: tl.constexpr, +): + token_index = tl.program_id(0) + token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N) + head_index = tl.program_id(1) + head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H) + + if rotary_interleaved: + for d in range(0, BLOCK_D // 2): + dim_range_x = d * 2 + dim_range_y = d * 2 + 1 + + rotary_embedding_rw_kernel( + state_out, + state, + cos, + sin, + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + num_tokens, + num_heads, + token_range, + head_range, + dim_range_x, + dim_range_y, + rotary_interleaved, + ) + else: + dim_range_x = tl.arange(0, BLOCK_D // 2) + dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D) + rotary_embedding_rw_kernel( + state_out, + state, + cos, + sin, + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + num_tokens, + num_heads, + token_range, + head_range, + dim_range_x, + dim_range_y, + rotary_interleaved, + ) + + +def apply_rotary_pos_emb( + q, + k, + cos, + sin, + position_ids = None, + rotary_interleaved: bool = False, + inplace: bool = False, +): + """ + Apply rotary position embedding to q and k + + Args: + q: (*, q_heads, head_dim) + k: (*, k_heads, head_dim) + cos: (max_seq_len, head_dim // 2) + sin: (max_seq_len, head_dim // 2) + position_ids: (*, ), optional, position ids for each token + rotary_interleaved: whether the head_dim is rotated in an interleaved way + + Returns: + q_embed: (*, q_heads, head_dim) + k_embed: (*, k_heads, head_dim) + """ + assert ( + k.shape[-1] == q.shape[-1] + ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}" + assert ( + cos.shape[-1] == 
sin.shape[-1] + ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}" + assert ( + cos.shape[-1] * 2 == q.shape[-1] + ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}" + assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension" + assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension" + + q_shape = q.shape + k_shape = k.shape + + assert ( + q.shape[:-2] == k.shape[:-2] + ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}" + if position_ids is None: + assert ( + len(q.shape) == 4 + ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}" + seq_len = q.shape[-3] + else: + assert ( + position_ids.shape == q.shape[:-2] + ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}" + + position_ids = position_ids.view(-1) + seq_len = None + + q = q.view(-1, q.shape[-2], q.shape[-1]) + k = k.view(-1, k.shape[-2], k.shape[-1]) + + q_embed = torch.empty_like(q) + k_embed = torch.empty_like(k) + + def torch_rotary_embedding(state_out, state, cos, sin): + num_tokens = state.shape[0] + num_heads = state.shape[1] + head_dim = state.shape[-1] + + BLOCK_N = 8 + BLOCK_H = 4 + grid = ( + triton.cdiv(num_tokens, BLOCK_N), + triton.cdiv(num_heads, BLOCK_H), + ) + with torch_device_fn.device(state_out.device): + if True: + if position_ids is None: + cos = cos[: q_shape[-3], None, :] + sin = sin[: q_shape[-3], None, :] + else: + cos = cos[position_ids, None, :] + sin = sin[position_ids, None, :] + + if rotary_interleaved: + cos = torch.repeat_interleave(cos, 2, dim=-1) + sin = torch.repeat_interleave(sin, 2, dim=-1) + orig_cos = cos + orig_sin = sin + for _ in range(q_shape[0] - 1): + cos = torch.cat((cos, orig_cos), dim=0) + sin = torch.cat((sin, orig_sin), dim=0) + rotary_embedding_siso_kernel[grid]( + state_out, + state, + cos, + sin, + state.stride(0), + state.stride(1), + state.stride(2), + cos.stride(0), + cos.stride(2), + num_tokens, + num_heads, + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + BLOCK_D=head_dim, + rotary_interleaved=rotary_interleaved, + ) + + torch_rotary_embedding(q_embed, q, cos, sin) + torch_rotary_embedding(k_embed, k, cos, sin) + + q_embed = q_embed.view(q_shape) + k_embed = k_embed.view(k_shape) + return q_embed, k_embed + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +# for test +def rotate_interleave(x): + """Rotates interleave the hidden dims of the input.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def torch_apply_rotary_pos_emb( + q, + k, + cos, + sin, + position_ids = None, + rotary_interleaved: bool = False, +): + q = q.float() + k = k.float() + if position_ids is None: + cos = cos[None, : q.size(-3), None, :] + sin = sin[None, : q.size(-3), None, :] + else: + cos = cos[position_ids].unsqueeze(-2) 
# [bs, seq_len, 1, dim/2] + sin = sin[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + if rotary_interleaved: + cos = torch.repeat_interleave(cos, 2, dim=-1) # [bs, seq_len, 1, dim] + sin = torch.repeat_interleave(sin, 2, dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_interleave + else: + cos = torch.cat([cos, cos], dim=-1) # [bs, seq_len, 1, dim] + sin = torch.cat([sin, sin], dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_half + + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) + + return q_embed, k_embed + + +def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device=device): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +if __name__ == "__main__": + # param + batch_size = 2 + max_seq_len = 16 + q_heads = 6 + k_heads = 2 + head_dim = 8 + rotary_interleaved = True + has_pos_id = True + dtype = torch.float32 + seq_len = 12 + + # inp + q = torch.randn( + (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device=device + ) + k = torch.randn( + (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device=device + ) + position_ids = torch.randint( + 0, max_seq_len, (batch_size, seq_len), device=device + ) + cos, sin = get_rope_cos_sin(max_seq_len, head_dim, dtype, device=device) + ref_q = q.cpu() + ref_k = k.cpu() + ref_cos = cos.cpu() + ref_sin = sin.cpu() + ref_position_ids = position_ids.cpu() + + # op + q_embed_ref, k_embed_ref = torch_apply_rotary_pos_emb( + q=ref_q, + k=ref_k, + cos=ref_cos, + sin=ref_sin, + position_ids=ref_position_ids if has_pos_id else None, + rotary_interleaved=rotary_interleaved, + ) + q_embed_out, k_embed_out = apply_rotary_pos_emb( + q=q, + k=k, + cos=cos, + sin=sin, + position_ids=position_ids if has_pos_id else None, + rotary_interleaved=rotary_interleaved, + ) + check("q_embed", q_embed_ref, q_embed_out) + check("k_embed", k_embed_ref, k_embed_out) diff --git a/python/test/ops/apply_rotary_pos_emb/triton-ascend/rotary_embedding_siso_kernel.ttadapter b/python/test/ops/apply_rotary_pos_emb/triton-ascend/rotary_embedding_siso_kernel.ttadapter new file mode 100644 index 000000000..887dfc56b --- /dev/null +++ b/python/test/ops/apply_rotary_pos_emb/triton-ascend/rotary_embedding_siso_kernel.ttadapter @@ -0,0 +1,107 @@ +module { + func.func @rotary_embedding_siso_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg5: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %cst = arith.constant 0.000000e+00 : f32 + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = arith.muli %arg14, %c8_i32 : i32 + %1 = arith.muli %arg15, %c4_i32 : i32 + scf.for %arg17 = %c0_i32 to %c4_i32 step %c1_i32 : i32 { + %2 = 
arith.muli %arg17, %c2_i32 : i32 + %3 = arith.addi %2, %c1_i32 : i32 + %4 = arith.index_cast %0 : i32 to index + %5 = arith.index_cast %arg6 : i32 to index + %6 = arith.muli %4, %5 : index + %7 = arith.index_cast %1 : i32 to index + %8 = arith.index_cast %arg7 : i32 to index + %9 = arith.muli %7, %8 : index + %10 = arith.index_cast %2 : i32 to index + %11 = arith.addi %6, %10 : index + %12 = arith.addi %11, %9 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%12], sizes: [8, 4, 1], strides: [%5, %8, 1] : memref to memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x4x1xf32> + %13 = arith.addi %4, %c8 : index + %14 = arith.index_cast %arg9 : i32 to index + %15 = arith.maxsi %4, %14 : index + %16 = arith.minsi %13, %15 : index + %17 = arith.subi %16, %4 : index + %18 = arith.addi %7, %c4 : index + %19 = arith.index_cast %arg10 : i32 to index + %20 = arith.maxsi %7, %19 : index + %21 = arith.minsi %18, %20 : index + %22 = arith.subi %21, %7 : index + %23 = arith.minsi %17, %c8 : index + %24 = arith.minsi %22, %c4 : index + %25 = arith.cmpi slt, %23, %c8 : index + %26 = arith.cmpi slt, %24, %c4 : index + %27 = arith.ori %25, %26 : i1 + scf.if %27 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x4x1xf32>) + } + %subview = memref.subview %reinterpret_cast[0, 0, 0] [%23, %24, 1] [1, 1, 1] : memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> to memref> + %subview_0 = memref.subview %alloc[0, 0, 0] [%23, %24, 1] [1, 1, 1] : memref<8x4x1xf32> to memref> + memref.copy %subview, %subview_0 : memref> to memref> + %28 = bufferization.to_tensor %alloc restrict writable : memref<8x4x1xf32> + %29 = arith.index_cast %3 : i32 to index + %30 = arith.addi %6, %29 : index + %31 = arith.addi %30, %9 : index + %reinterpret_cast_1 = memref.reinterpret_cast %arg3 to offset: [%31], sizes: [8, 4, 1], strides: [%5, %8, 1] : memref to memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> + %alloc_2 = memref.alloc() : memref<8x4x1xf32> + scf.if %27 { + linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<8x4x1xf32>) + } + %subview_3 = memref.subview %reinterpret_cast_1[0, 0, 0] [%23, %24, 1] [1, 1, 1] : memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> to memref> + %subview_4 = memref.subview %alloc_2[0, 0, 0] [%23, %24, 1] [1, 1, 1] : memref<8x4x1xf32> to memref> + memref.copy %subview_3, %subview_4 : memref> to memref> + %32 = bufferization.to_tensor %alloc_2 restrict writable : memref<8x4x1xf32> + %33 = arith.index_cast %arg8 : i32 to index + %34 = arith.muli %4, %33 : index + %35 = arith.addi %34, %10 : index + %reinterpret_cast_5 = memref.reinterpret_cast %arg4 to offset: [%35], sizes: [8, 1, 1], strides: [%33, 1, 1] : memref to memref<8x1x1xf32, strided<[?, 1, 1], offset: ?>> + %alloc_6 = memref.alloc() : memref<8x1x1xf32> + %36 = arith.cmpi slt, %17, %c8 : index + scf.if %36 { + linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<8x1x1xf32>) + } + %subview_7 = memref.subview %reinterpret_cast_5[0, 0, 0] [%17, 1, 1] [1, 1, 1] : memref<8x1x1xf32, strided<[?, 1, 1], offset: ?>> to memref> + %subview_8 = memref.subview %alloc_6[0, 0, 0] [%17, 1, 1] [1, 1, 1] : memref<8x1x1xf32> to memref> + memref.copy %subview_7, %subview_8 : memref> to memref> + %37 = bufferization.to_tensor %alloc_6 restrict writable : memref<8x1x1xf32> + %38 = arith.addi %34, %29 : index + %reinterpret_cast_9 = memref.reinterpret_cast %arg5 to offset: [%38], sizes: [8, 1, 1], strides: [%33, 1, 1] : memref to memref<8x1x1xf32, strided<[?, 1, 1], offset: ?>> + %alloc_10 = memref.alloc() : 
memref<8x1x1xf32> + scf.if %36 { + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<8x1x1xf32>) + } + %subview_11 = memref.subview %reinterpret_cast_9[0, 0, 0] [%17, 1, 1] [1, 1, 1] : memref<8x1x1xf32, strided<[?, 1, 1], offset: ?>> to memref> + %subview_12 = memref.subview %alloc_10[0, 0, 0] [%17, 1, 1] [1, 1, 1] : memref<8x1x1xf32> to memref> + memref.copy %subview_11, %subview_12 : memref> to memref> + %39 = bufferization.to_tensor %alloc_10 restrict writable : memref<8x1x1xf32> + %40 = tensor.empty() : tensor<8x4x1xf32> + %collapsed = tensor.collapse_shape %37 [[0], [1, 2]] : tensor<8x1x1xf32> into tensor<8x1xf32> + %broadcasted = linalg.broadcast ins(%collapsed : tensor<8x1xf32>) outs(%40 : tensor<8x4x1xf32>) dimensions = [1] + %41 = arith.mulf %28, %broadcasted : tensor<8x4x1xf32> + %collapsed_13 = tensor.collapse_shape %39 [[0], [1, 2]] : tensor<8x1x1xf32> into tensor<8x1xf32> + %broadcasted_14 = linalg.broadcast ins(%collapsed_13 : tensor<8x1xf32>) outs(%40 : tensor<8x4x1xf32>) dimensions = [1] + %42 = arith.mulf %32, %broadcasted_14 : tensor<8x4x1xf32> + %43 = arith.subf %41, %42 : tensor<8x4x1xf32> + %44 = arith.mulf %28, %broadcasted_14 : tensor<8x4x1xf32> + %45 = arith.mulf %32, %broadcasted : tensor<8x4x1xf32> + %46 = arith.addf %44, %45 : tensor<8x4x1xf32> + %reinterpret_cast_15 = memref.reinterpret_cast %arg2 to offset: [%12], sizes: [8, 4, 1], strides: [%5, %8, 1] : memref to memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> + %extracted_slice = tensor.extract_slice %43[0, 0, 0] [%23, %24, 1] [1, 1, 1] : tensor<8x4x1xf32> to tensor + %subview_16 = memref.subview %reinterpret_cast_15[0, 0, 0] [%23, %24, 1] [1, 1, 1] : memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_16 : (tensor, memref>) -> () + %reinterpret_cast_17 = memref.reinterpret_cast %arg2 to offset: [%31], sizes: [8, 4, 1], strides: [%5, %8, 1] : memref to memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> + %extracted_slice_18 = tensor.extract_slice %46[0, 0, 0] [%23, %24, 1] [1, 1, 1] : tensor<8x4x1xf32> to tensor + %subview_19 = memref.subview %reinterpret_cast_17[0, 0, 0] [%23, %24, 1] [1, 1, 1] : memref<8x4x1xf32, strided<[?, ?, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice_18 in writable %subview_19 : (tensor, memref>) -> () + } + return + } +} + diff --git a/python/test/ops/apply_rotary_pos_emb/triton-ascend/rotary_embedding_siso_kernel.ttir b/python/test/ops/apply_rotary_pos_emb/triton-ascend/rotary_embedding_siso_kernel.ttir new file mode 100644 index 000000000..04204b607 --- /dev/null +++ b/python/test/ops/apply_rotary_pos_emb/triton-ascend/rotary_embedding_siso_kernel.ttir @@ -0,0 +1,155 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0) +module { + tt.func public @rotary_embedding_siso_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg3: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), 
%arg4: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg5: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg6: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg7: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0), %arg8: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":100:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<8x4x1xf32> loc(#loc1) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<8x1x1xf32> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c8_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<8xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<8xi32> loc(#loc5) + %5 = tt.get_program_id y : i32 loc(#loc6) + %6 = arith.muli %5, %c4_i32 : i32 loc(#loc7) + %7 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc8) + %8 = tt.splat %6 : i32 -> tensor<4xi32> loc(#loc9) + %9 = arith.addi %8, %7 : tensor<4xi32> loc(#loc9) + %10 = tt.expand_dims %4 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc46) + %11 = tt.expand_dims %10 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> loc(#loc46) + %12 = tt.splat %arg4 : i32 -> tensor<8x1x1xi32> loc(#loc47) + %13 = arith.muli %11, %12 : tensor<8x1x1xi32> loc(#loc47) + %14 = tt.expand_dims %9 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc48) + %15 = tt.expand_dims %14 {axis = 2 : i32} : tensor<1x4xi32> -> tensor<1x4x1xi32> loc(#loc48) + %16 = tt.splat %arg5 : i32 -> tensor<1x4x1xi32> loc(#loc49) + %17 = arith.muli %15, %16 : tensor<1x4x1xi32> loc(#loc49) + %18 = tt.broadcast %13 : tensor<8x1x1xi32> -> tensor<8x4x1xi32> loc(#loc50) + %19 = tt.broadcast %17 : tensor<1x4x1xi32> -> tensor<8x4x1xi32> loc(#loc50) + %20 = arith.addi %18, %19 : tensor<8x4x1xi32> loc(#loc50) + %21 = tt.splat %arg6 : i32 -> tensor<8x1x1xi32> loc(#loc51) + %22 = arith.muli %11, %21 : tensor<8x1x1xi32> loc(#loc51) + %23 = tt.splat %arg7 : i32 -> tensor<8x1x1xi32> loc(#loc52) + %24 = arith.cmpi slt, %11, %23 : tensor<8x1x1xi32> loc(#loc52) + %25 = tt.splat %arg8 : i32 -> tensor<1x4x1xi32> loc(#loc53) + %26 = arith.cmpi slt, %15, %25 : tensor<1x4x1xi32> loc(#loc53) + %27 = tt.broadcast %24 : tensor<8x1x1xi1> -> tensor<8x4x1xi1> loc(#loc54) + %28 = tt.broadcast %26 : tensor<1x4x1xi1> -> tensor<8x4x1xi1> loc(#loc54) + %29 = arith.andi %27, %28 : tensor<8x4x1xi1> loc(#loc54) + %30 = tt.splat %arg1 : !tt.ptr -> tensor<8x4x1x!tt.ptr> loc(#loc55) + %31 = tt.splat %arg2 : !tt.ptr -> tensor<8x1x1x!tt.ptr> loc(#loc56) + %32 = tt.splat %arg3 : !tt.ptr -> tensor<8x1x1x!tt.ptr> loc(#loc57) + %33 = tt.splat %arg0 : !tt.ptr -> tensor<8x4x1x!tt.ptr> loc(#loc58) + scf.for %arg9 = %c0_i32 to %c4_i32 step %c1_i32 : i32 { + %34 = arith.muli %arg9, %c2_i32 : i32 loc(#loc25) + %35 = arith.addi %34, %c1_i32 : i32 loc(#loc26) + %36 = tt.splat %34 : i32 -> tensor<8x4x1xi32> loc(#loc59) + %37 = arith.addi %20, %36 : 
tensor<8x4x1xi32> loc(#loc59) + %38 = tt.splat %35 : i32 -> tensor<8x4x1xi32> loc(#loc60) + %39 = arith.addi %20, %38 : tensor<8x4x1xi32> loc(#loc60) + %40 = tt.splat %34 : i32 -> tensor<8x1x1xi32> loc(#loc61) + %41 = arith.addi %22, %40 : tensor<8x1x1xi32> loc(#loc61) + %42 = tt.splat %35 : i32 -> tensor<8x1x1xi32> loc(#loc62) + %43 = arith.addi %22, %42 : tensor<8x1x1xi32> loc(#loc62) + %44 = tt.addptr %30, %37 : tensor<8x4x1x!tt.ptr>, tensor<8x4x1xi32> loc(#loc55) + %45 = tt.load %44, %29, %cst : tensor<8x4x1x!tt.ptr> loc(#loc63) + %46 = tt.addptr %30, %39 : tensor<8x4x1x!tt.ptr>, tensor<8x4x1xi32> loc(#loc64) + %47 = tt.load %46, %29, %cst : tensor<8x4x1x!tt.ptr> loc(#loc65) + %48 = tt.addptr %31, %41 : tensor<8x1x1x!tt.ptr>, tensor<8x1x1xi32> loc(#loc56) + %49 = tt.load %48, %24, %cst_0 : tensor<8x1x1x!tt.ptr> loc(#loc66) + %50 = tt.addptr %32, %43 : tensor<8x1x1x!tt.ptr>, tensor<8x1x1xi32> loc(#loc57) + %51 = tt.load %50, %24, %cst_0 : tensor<8x1x1x!tt.ptr> loc(#loc67) + %52 = tt.broadcast %49 : tensor<8x1x1xf32> -> tensor<8x4x1xf32> loc(#loc68) + %53 = arith.mulf %45, %52 : tensor<8x4x1xf32> loc(#loc68) + %54 = tt.broadcast %51 : tensor<8x1x1xf32> -> tensor<8x4x1xf32> loc(#loc69) + %55 = arith.mulf %47, %54 : tensor<8x4x1xf32> loc(#loc69) + %56 = arith.subf %53, %55 : tensor<8x4x1xf32> loc(#loc70) + %57 = arith.mulf %45, %54 : tensor<8x4x1xf32> loc(#loc71) + %58 = arith.mulf %47, %52 : tensor<8x4x1xf32> loc(#loc72) + %59 = arith.addf %57, %58 : tensor<8x4x1xf32> loc(#loc73) + %60 = tt.addptr %33, %37 : tensor<8x4x1x!tt.ptr>, tensor<8x4x1xi32> loc(#loc58) + tt.store %60, %56, %29 : tensor<8x4x1x!tt.ptr> loc(#loc74) + %61 = tt.addptr %33, %39 : tensor<8x4x1x!tt.ptr>, tensor<8x4x1xi32> loc(#loc75) + tt.store %61, %59, %29 : tensor<8x4x1x!tt.ptr> loc(#loc76) + } loc(#loc24) + tt.return loc(#loc45) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":117:32) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":118:32) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":118:55) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":118:42) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":119:31) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":120:30) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":120:53) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":120:40) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":36:20) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":143:16) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":36:37) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":37:21) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":37:38) +#loc15 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":37:10) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":47:37) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":60:43) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":61:39) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":61:11) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":59:16) +#loc21 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":72:14) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":77:14) +#loc23 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":86:20) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":123:26) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":124:30) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":125:34) +#loc27 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":38:10) +#loc28 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":43:10) +#loc29 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":48:10) +#loc30 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":53:14) +#loc31 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":59:8) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":65:16) +#loc33 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":65:8) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":72:8) +#loc35 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":77:8) +#loc36 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":82:22) +#loc37 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":82:45) +#loc38 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":82:35) +#loc39 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":83:22) +#loc40 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":83:45) +#loc41 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":83:35) +#loc42 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":87:8) +#loc43 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":92:20) +#loc44 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":93:8) +#loc45 = loc("/home/zhengyang/git/flagtree/python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py":122:4) +#loc46 = loc(callsite(#loc10 at #loc11)) +#loc47 = loc(callsite(#loc12 at #loc11)) +#loc48 = loc(callsite(#loc13 at #loc11)) +#loc49 = loc(callsite(#loc14 at #loc11)) +#loc50 = loc(callsite(#loc15 at #loc11)) +#loc51 = loc(callsite(#loc16 at #loc11)) +#loc52 = loc(callsite(#loc17 at #loc11)) +#loc53 = loc(callsite(#loc18 at #loc11)) +#loc54 = loc(callsite(#loc19 at #loc11)) +#loc55 = loc(callsite(#loc20 at #loc11)) +#loc56 = loc(callsite(#loc21 at #loc11)) +#loc57 = loc(callsite(#loc22 at #loc11)) +#loc58 = loc(callsite(#loc23 at #loc11)) +#loc59 = loc(callsite(#loc27 at #loc11)) +#loc60 = loc(callsite(#loc28 at #loc11)) +#loc61 = loc(callsite(#loc29 at #loc11)) +#loc62 = loc(callsite(#loc30 at #loc11)) +#loc63 = loc(callsite(#loc31 at #loc11)) +#loc64 = loc(callsite(#loc32 at #loc11)) +#loc65 = loc(callsite(#loc33 at #loc11)) +#loc66 = loc(callsite(#loc34 at #loc11)) +#loc67 = loc(callsite(#loc35 at #loc11)) +#loc68 = loc(callsite(#loc36 at #loc11)) +#loc69 = loc(callsite(#loc37 at #loc11)) +#loc70 = loc(callsite(#loc38 at #loc11)) +#loc71 = loc(callsite(#loc39 at #loc11)) +#loc72 = loc(callsite(#loc40 at #loc11)) +#loc73 = loc(callsite(#loc41 at #loc11)) +#loc74 = loc(callsite(#loc42 at #loc11)) +#loc75 = loc(callsite(#loc43 at #loc11)) +#loc76 = loc(callsite(#loc44 at #loc11)) diff --git a/python/test/ops/argmin/argmin.py b/python/test/ops/argmin/argmin.py new file mode 100644 index 000000000..710fb3faa --- /dev/null +++ b/python/test/ops/argmin/argmin.py @@ -0,0 +1,153 @@ +import math + +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def get_dtype_max(dtype: tl.constexpr): + """get a value which is greater that all other values of that dtype""" + # extract the tl.dtype from tl.constexpr so as to use its methods + dtype_ = dtype.value + if dtype_.is_floating(): + value: tl.constexpr = float("inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2 ** (width - 1) - 1 + return value + if dtype_.is_int_unsigned(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2**width - 1 + return value + + +@triton.jit +def argmin_kernel( + inp, + out_index, + M, + N, + K, + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 16, +): + # set offset + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) + argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) + for start_n in range(0, N, BLOCK_N): + n_offset = start_n + tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k + mask = m_offset[:, None] < M and n_offset[None, :] < N + inp_ptrs = inp + offset + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) + # tl.bfloat is 
promoted to tl.float32 by tl.min + local_min, local_argmin = tl.min( + inp_vals, 1, return_indices=True, return_indices_tie_break_left=True + ) + # if return indices is not supported, call a tl.argmin in addition + # local_argmin = tl.argmin(inp_vals, 1) + update = local_min < min_values + min_values = tl.where(update, local_min, min_values) + argmin_values = tl.where(update, start_n + local_argmin, argmin_values) + + offset_index = m_offset * K + pid_k + out_index_ptrs = out_index + offset_index + mask1 = m_offset < M + tl.store(out_index_ptrs, argmin_values, mask=mask1) + + +def argmin(inp, dim=None, keepdim=False, *, dtype=None): + if dim is not None: + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + shape = inp.shape + dim = dim % inp.ndim + N = shape[dim] + M = math.prod(shape[:dim]) + K = inp.numel() // M // N + + inp = inp.contiguous() + + shape_list = list(shape) + shape_list[dim] = 1 + out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device) + if not keepdim: + out_index = torch.squeeze(out_index, dim) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + K, + ) + with torch_device_fn.device(inp.device): + argmin_kernel[grid]( + inp, + out_index, + M, + N, + K, + ) + + return out_index + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (1, 32) + dim = 1 + keepdim = True + dtype = torch.float32 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.argmin(ref_inp, dim=dim, keepdim=keepdim) + res_out = argmin(inp, dim=dim, keepdim=keepdim) + check("value", ref_out, res_out) diff --git a/python/test/ops/argmin/argmin_ascend_perf.py b/python/test/ops/argmin/argmin_ascend_perf.py new file mode 100644 index 000000000..6f779595f --- /dev/null +++ b/python/test/ops/argmin/argmin_ascend_perf.py @@ -0,0 +1,154 @@ +import math + +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def get_dtype_max(dtype: tl.constexpr): + """get a value which is greater that all other values of that dtype""" + # extract the tl.dtype from tl.constexpr so as to use its methods + dtype_ = dtype.value + if dtype_.is_floating(): + value: tl.constexpr = float("inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2 ** (width - 1) - 1 + return value + if dtype_.is_int_unsigned(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2**width - 1 + return value + + 
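+# Note (added commentary, not part of the generated test): get_dtype_max returns
+# a sentinel that compares greater-or-equal to every representable value of the
+# given dtype. Assuming IEEE-754 floats and two's-complement integers, e.g.:
+#   get_dtype_max(tl.float32) -> float("inf")
+#   get_dtype_max(tl.int32)   -> 2**31 - 1 = 2147483647
+#   get_dtype_max(tl.uint8)   -> 2**8 - 1  = 255
+# The argmin kernel below fills masked-out loads with `other=max_value` and
+# seeds its running minimum with the same sentinel, so padded / out-of-bounds
+# lanes can never be selected as the argmin.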
+@triton.jit +def argmin_kernel( + inp, + out_index, + M, + N, + K, + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 16, +): + # set offset + pid_m = tl.program_id(0) + # pid_k = tl.program_id(1) + for pid_k in range(K): + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) + argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) + for start_n in range(0, N, BLOCK_N): + n_offset = start_n + tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k + mask = m_offset[:, None] < M and n_offset[None, :] < N + inp_ptrs = inp + offset + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) + # tl.bfloat is promoted to tl.float32 by tl.min + local_min, local_argmin = tl.min( + inp_vals, 1, return_indices=True, return_indices_tie_break_left=True + ) + # if return indices is not supported, call a tl.argmin in addition + # local_argmin = tl.argmin(inp_vals, 1) + update = local_min < min_values + min_values = tl.where(update, local_min, min_values) + argmin_values = tl.where(update, start_n + local_argmin, argmin_values) + + offset_index = m_offset * K + pid_k + out_index_ptrs = out_index + offset_index + mask1 = m_offset < M + tl.store(out_index_ptrs, argmin_values, mask=mask1) + + +def argmin(inp, dim=None, keepdim=False, *, dtype=None): + if dim is not None: + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + shape = inp.shape + dim = dim % inp.ndim + N = shape[dim] + M = math.prod(shape[:dim]) + K = inp.numel() // M // N + + inp = inp.contiguous() + + shape_list = list(shape) + shape_list[dim] = 1 + out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device) + if not keepdim: + out_index = torch.squeeze(out_index, dim) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + # K, + ) + with torch_device_fn.device(inp.device): + argmin_kernel[grid]( + inp, + out_index, + M, + N, + K, + ) + + return out_index + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (1, 32) + dim = 1 + keepdim = True + dtype = torch.float32 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.argmin(ref_inp, dim=dim, keepdim=keepdim) + res_out = argmin(inp, dim=dim, keepdim=keepdim) + check("value", ref_out, res_out) diff --git a/python/test/ops/argmin/triton-ascend-perf/argmin_kernel.ttadapter b/python/test/ops/argmin/triton-ascend-perf/argmin_kernel.ttadapter new file mode 100644 index 000000000..285b294d3 --- /dev/null +++ b/python/test/ops/argmin/triton-ascend-perf/argmin_kernel.ttadapter @@ -0,0 +1,88 @@ +#map = affine_map<(d0) -> (d0)> +module { + func.func @argmin_kernel(%arg0: memref, 
%arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c-1_i32 = arith.constant -1 : i32 + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant 0x7F800000 : f32 + %0 = tensor.empty() : tensor<8xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32> + %2 = tensor.empty() : tensor<8xi64> + %3 = linalg.fill ins(%c0_i64 : i64) outs(%2 : tensor<8xi64>) -> tensor<8xi64> + %4 = arith.muli %arg8, %c8_i32 : i32 + %5 = tensor.empty() : tensor<16xi32> + %6 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%5 : tensor<16xi32>) { + ^bb0(%out: i32): + %14 = linalg.index 0 : index + %15 = arith.index_cast %14 : index to i32 + linalg.yield %15 : i32 + } -> tensor<16xi32> + %7 = tensor.empty() : tensor<8x16xi32> + %broadcasted = linalg.broadcast ins(%6 : tensor<16xi32>) outs(%7 : tensor<8x16xi32>) dimensions = [0] + %8:2 = scf.for %arg11 = %c0_i32 to %arg4 step %c16_i32 iter_args(%arg12 = %1, %arg13 = %3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %14 = arith.index_cast %4 : i32 to index + %15 = arith.index_cast %arg4 : i32 to index + %16 = arith.muli %14, %15 : index + %17 = arith.index_cast %arg11 : i32 to index + %18 = arith.addi %16, %17 : index + %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%18], sizes: [8, 16], strides: [%15, 1] : memref to memref<8x16xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x16xf32> + %19 = arith.addi %14, %c8 : index + %20 = arith.maxsi %14, %c1 : index + %21 = arith.minsi %19, %20 : index + %22 = arith.subi %21, %14 : index + %23 = arith.addi %17, %c16 : index + %24 = arith.maxsi %17, %15 : index + %25 = arith.minsi %23, %24 : index + %26 = arith.subi %25, %17 : index + %27 = arith.minsi %22, %c8 : index + %28 = arith.minsi %26, %c16 : index + %29 = arith.cmpi slt, %27, %c8 : index + %30 = arith.cmpi slt, %28, %c16 : index + %31 = arith.ori %29, %30 : i1 + scf.if %31 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16xf32>) + } + %subview_1 = memref.subview %reinterpret_cast_0[0, 0] [%27, %28] [1, 1] : memref<8x16xf32, strided<[?, 1], offset: ?>> to memref> + %subview_2 = memref.subview %alloc[0, 0] [%27, %28] [1, 1] : memref<8x16xf32> to memref> + memref.copy %subview_1, %subview_2 : memref> to memref> + %32 = bufferization.to_tensor %alloc restrict writable : memref<8x16xf32> + %33 = tensor.empty() : tensor<8xi32> + %34 = linalg.fill ins(%c-1_i32 : i32) outs(%33 : tensor<8xi32>) -> tensor<8xi32> + %reduced:2 = linalg.reduce ins(%32, %broadcasted : tensor<8x16xf32>, tensor<8x16xi32>) outs(%1, %34 : tensor<8xf32>, tensor<8xi32>) dimensions = [1] {reduce_mode = "min_with_index"} + (%in: f32, %in_3: i32, %init: f32, %init_4: i32) { + %41 = arith.cmpf olt, %in, %init : f32 + %42 = arith.cmpf oeq, %in, %init : f32 + %43 = arith.cmpi slt, %in_3, %init_4 : i32 + %44 = arith.andi %42, %43 : i1 + %45 = arith.ori %41, %44 : i1 + %46 = arith.select %45, %in, %init : f32 + %47 = arith.select %45, %in_3, %init_4 : i32 + linalg.yield %46, %47 : f32, 
i32 + } + %35 = arith.cmpf olt, %reduced#0, %arg12 : tensor<8xf32> + %36 = arith.select %35, %reduced#0, %arg12 : tensor<8xi1>, tensor<8xf32> + %37 = linalg.fill ins(%arg11 : i32) outs(%33 : tensor<8xi32>) -> tensor<8xi32> + %38 = arith.addi %37, %reduced#1 : tensor<8xi32> + %39 = arith.extsi %38 : tensor<8xi32> to tensor<8xi64> + %40 = arith.select %35, %39, %arg13 : tensor<8xi1>, tensor<8xi64> + scf.yield %36, %40 : tensor<8xf32>, tensor<8xi64> + } + %9 = arith.index_cast %4 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%9], sizes: [8], strides: [1] : memref to memref<8xi64, strided<[1], offset: ?>> + %10 = arith.addi %9, %c8 : index + %11 = arith.maxsi %9, %c1 : index + %12 = arith.minsi %10, %11 : index + %13 = arith.subi %12, %9 : index + %extracted_slice = tensor.extract_slice %8#1[0] [%13] [1] : tensor<8xi64> to tensor + %subview = memref.subview %reinterpret_cast[0] [%13] [1] : memref<8xi64, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/argmin/triton-ascend-perf/argmin_kernel.ttir b/python/test/ops/argmin/triton-ascend-perf/argmin_kernel.ttir new file mode 100644 index 000000000..5d1578cef --- /dev/null +++ b/python/test/ops/argmin/triton-ascend-perf/argmin_kernel.ttir @@ -0,0 +1,123 @@ +#loc = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":37:0) +#loc1 = loc(unknown) +#loc15 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":65:26) +#loc38 = loc(callsite(#loc1 at #loc15)) +module { + tt.func public @argmin_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":37:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":37:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":37:0)) attributes {noinline = false} { + %cst = arith.constant dense<0x7F800000> : tensor<8x16xf32> loc(#loc1) + %cst_0 = arith.constant dense<0x7F800000> : tensor<8xf32> loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<8xi32> loc(#loc1) + %cst_2 = arith.constant dense<1> : tensor<8x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<8xi64> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c8_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<8xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<8xi32> loc(#loc5) + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc6) + %6 = tt.expand_dims %4 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc7) + %7 = tt.splat %arg2 : i32 -> tensor<8x1xi32> loc(#loc8) + %8 = arith.muli %6, %7 : tensor<8x1xi32> loc(#loc8) + %9 = tt.broadcast %8 : tensor<8x1xi32> -> tensor<8x16xi32> loc(#loc9) + %10 = arith.cmpi slt, %6, %cst_2 : tensor<8x1xi32> loc(#loc10) + %11 = tt.splat %arg2 : i32 -> tensor<1x16xi32> loc(#loc11) + %12 = tt.broadcast %10 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc12) + %13 = tt.splat %arg0 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc13) + %14 = tt.expand_dims %5 {axis = 0 : i32} : 
tensor<16xi32> -> tensor<1x16xi32> loc(#loc37) + %15 = tt.broadcast %14 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc37) + %16:2 = scf.for %arg3 = %c0_i32 to %arg2 step %c16_i32 iter_args(%arg4 = %cst_0, %arg5 = %cst_3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %20 = tt.splat %arg3 : i32 -> tensor<16xi32> loc(#loc17) + %21 = arith.addi %20, %5 : tensor<16xi32> loc(#loc17) + %22 = tt.expand_dims %21 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc18) + %23 = tt.broadcast %22 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc9) + %24 = arith.addi %9, %23 : tensor<8x16xi32> loc(#loc9) + %25 = arith.cmpi slt, %22, %11 : tensor<1x16xi32> loc(#loc11) + %26 = tt.broadcast %25 : tensor<1x16xi1> -> tensor<8x16xi1> loc(#loc12) + %27 = arith.andi %12, %26 : tensor<8x16xi1> loc(#loc12) + %28 = tt.addptr %13, %24 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc13) + %29 = tt.load %28, %27, %cst : tensor<8x16x!tt.ptr> loc(#loc19) + %30:2 = "tt.reduce"(%29, %15) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc15)), %arg7: i32 loc(callsite(#loc1 at #loc15)), %arg8: f32 loc(callsite(#loc1 at #loc15)), %arg9: i32 loc(callsite(#loc1 at #loc15))): + %37 = arith.cmpf oeq, %arg6, %arg8 : f32 loc(#loc53) + %38 = arith.cmpi slt, %arg7, %arg9 : i32 loc(#loc54) + %39 = arith.andi %37, %38 : i1 loc(#loc55) + %40 = arith.cmpf olt, %arg6, %arg8 : f32 loc(#loc56) + %41 = arith.ori %40, %39 : i1 loc(#loc57) + %42 = arith.select %41, %arg6, %arg8 : f32 loc(#loc58) + %43 = arith.select %41, %arg7, %arg9 : i32 loc(#loc59) + tt.reduce.return %42, %43 : f32, i32 loc(#loc37) + }) : (tensor<8x16xf32>, tensor<8x16xi32>) -> (tensor<8xf32>, tensor<8xi32>) loc(#loc37) + %31 = arith.cmpf olt, %30#0, %arg4 : tensor<8xf32> loc(#loc28) + %32 = arith.select %31, %30#0, %arg4 : tensor<8xi1>, tensor<8xf32> loc(#loc29) + %33 = tt.splat %arg3 : i32 -> tensor<8xi32> loc(#loc30) + %34 = arith.addi %33, %30#1 : tensor<8xi32> loc(#loc30) + %35 = arith.extsi %34 : tensor<8xi32> to tensor<8xi64> loc(#loc31) + %36 = arith.select %31, %35, %arg5 : tensor<8xi1>, tensor<8xi64> loc(#loc31) + scf.yield %32, %36 : tensor<8xf32>, tensor<8xi64> loc(#loc32) + } loc(#loc16) + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> loc(#loc33) + %18 = tt.addptr %17, %4 : tensor<8x!tt.ptr>, tensor<8xi32> loc(#loc33) + %19 = arith.cmpi slt, %4, %cst_1 : tensor<8xi32> loc(#loc34) + tt.store %18, %16#1, %19 : tensor<8x!tt.ptr> loc(#loc35) + tt.return loc(#loc36) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":47:26) +#loc3 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":50:27) +#loc4 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":50:50) +#loc5 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":50:37) +#loc6 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":58:46) +#loc7 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":59:30) +#loc8 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":59:41) +#loc9 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":59:49) +#loc10 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":60:39) +#loc11 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":60:65) +#loc12 = 
loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":60:45) +#loc13 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":61:29) +#loc14 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":233:58) +#loc16 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":57:35) +#loc17 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":58:33) +#loc18 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":59:58) +#loc19 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":62:31) +#loc20 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:24) +#loc21 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":212:59) +#loc22 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:44) +#loc23 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:35) +#loc24 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:18) +#loc25 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:28) +#loc26 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":205:39) +#loc27 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":206:39) +#loc28 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":69:33) +#loc29 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":70:53) +#loc30 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":71:55) +#loc31 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":71:69) +#loc32 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":71:12) +#loc33 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":74:37) +#loc34 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":75:27) +#loc35 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":76:33) +#loc36 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin_ascend_perf.py":49:4) +#loc37 = loc(callsite(#loc14 at #loc15)) +#loc39 = loc(callsite(#loc20 at #loc21)) +#loc40 = loc(callsite(#loc22 at #loc21)) +#loc41 = loc(callsite(#loc23 at #loc21)) +#loc42 = loc(callsite(#loc24 at #loc21)) +#loc43 = loc(callsite(#loc25 at #loc21)) +#loc44 = loc(callsite(#loc26 at #loc21)) +#loc45 = loc(callsite(#loc27 at #loc21)) +#loc46 = loc(callsite(#loc39 at #loc14)) +#loc47 = loc(callsite(#loc40 at #loc14)) +#loc48 = loc(callsite(#loc41 at #loc14)) +#loc49 = loc(callsite(#loc42 at #loc14)) +#loc50 = loc(callsite(#loc43 at #loc14)) +#loc51 = loc(callsite(#loc44 at #loc14)) +#loc52 = loc(callsite(#loc45 at #loc14)) +#loc53 = loc(callsite(#loc46 at #loc15)) +#loc54 = loc(callsite(#loc47 at #loc15)) +#loc55 = loc(callsite(#loc48 at #loc15)) +#loc56 = loc(callsite(#loc49 at #loc15)) +#loc57 = loc(callsite(#loc50 at #loc15)) +#loc58 = loc(callsite(#loc51 at #loc15)) +#loc59 = loc(callsite(#loc52 at #loc15)) diff --git a/python/test/ops/argmin/triton-ascend/argmin_kernel.ttadapter b/python/test/ops/argmin/triton-ascend/argmin_kernel.ttadapter new file mode 100644 
index 000000000..e8b795c00 --- /dev/null +++ b/python/test/ops/argmin/triton-ascend/argmin_kernel.ttadapter @@ -0,0 +1,92 @@ +#map = affine_map<(d0) -> (d0)> +module { + func.func @argmin_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c-1_i32 = arith.constant -1 : i32 + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant 0x7F800000 : f32 + %0 = tensor.empty() : tensor<8xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32> + %2 = tensor.empty() : tensor<8xi64> + %3 = linalg.fill ins(%c0_i64 : i64) outs(%2 : tensor<8xi64>) -> tensor<8xi64> + %4 = arith.muli %arg8, %c8_i32 : i32 + %5 = tensor.empty() : tensor<16xi32> + %6 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%5 : tensor<16xi32>) { + ^bb0(%out: i32): + %16 = linalg.index 0 : index + %17 = arith.index_cast %16 : index to i32 + linalg.yield %17 : i32 + } -> tensor<16xi32> + %7 = tensor.empty() : tensor<8x16xi32> + %broadcasted = linalg.broadcast ins(%6 : tensor<16xi32>) outs(%7 : tensor<8x16xi32>) dimensions = [0] + %8:2 = scf.for %arg11 = %c0_i32 to %arg4 step %c16_i32 iter_args(%arg12 = %1, %arg13 = %3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %16 = arith.index_cast %4 : i32 to index + %17 = arith.index_cast %arg4 : i32 to index + %18 = arith.muli %16, %17 : index + %19 = arith.index_cast %arg11 : i32 to index + %20 = arith.index_cast %arg9 : i32 to index + %21 = arith.addi %18, %20 : index + %22 = arith.addi %21, %19 : index + %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%22], sizes: [8, 16], strides: [%17, 1] : memref to memref<8x16xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x16xf32> + %23 = arith.addi %16, %c8 : index + %24 = arith.maxsi %16, %c1 : index + %25 = arith.minsi %23, %24 : index + %26 = arith.subi %25, %16 : index + %27 = arith.addi %19, %c16 : index + %28 = arith.maxsi %19, %17 : index + %29 = arith.minsi %27, %28 : index + %30 = arith.subi %29, %19 : index + %31 = arith.minsi %26, %c8 : index + %32 = arith.minsi %30, %c16 : index + %33 = arith.cmpi slt, %31, %c8 : index + %34 = arith.cmpi slt, %32, %c16 : index + %35 = arith.ori %33, %34 : i1 + scf.if %35 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16xf32>) + } + %subview_1 = memref.subview %reinterpret_cast_0[0, 0] [%31, %32] [1, 1] : memref<8x16xf32, strided<[?, 1], offset: ?>> to memref> + %subview_2 = memref.subview %alloc[0, 0] [%31, %32] [1, 1] : memref<8x16xf32> to memref> + memref.copy %subview_1, %subview_2 : memref> to memref> + %36 = bufferization.to_tensor %alloc restrict writable : memref<8x16xf32> + %37 = tensor.empty() : tensor<8xi32> + %38 = linalg.fill ins(%c-1_i32 : i32) outs(%37 : tensor<8xi32>) -> tensor<8xi32> + %reduced:2 = linalg.reduce ins(%36, %broadcasted : tensor<8x16xf32>, tensor<8x16xi32>) outs(%1, %38 : tensor<8xf32>, tensor<8xi32>) dimensions = [1] {reduce_mode = "min_with_index"} + (%in: f32, %in_3: i32, %init: f32, %init_4: i32) { + %45 = arith.cmpf olt, 
%in, %init : f32 + %46 = arith.cmpf oeq, %in, %init : f32 + %47 = arith.cmpi slt, %in_3, %init_4 : i32 + %48 = arith.andi %46, %47 : i1 + %49 = arith.ori %45, %48 : i1 + %50 = arith.select %49, %in, %init : f32 + %51 = arith.select %49, %in_3, %init_4 : i32 + linalg.yield %50, %51 : f32, i32 + } + %39 = arith.cmpf olt, %reduced#0, %arg12 : tensor<8xf32> + %40 = arith.select %39, %reduced#0, %arg12 : tensor<8xi1>, tensor<8xf32> + %41 = linalg.fill ins(%arg11 : i32) outs(%37 : tensor<8xi32>) -> tensor<8xi32> + %42 = arith.addi %41, %reduced#1 : tensor<8xi32> + %43 = arith.extsi %42 : tensor<8xi32> to tensor<8xi64> + %44 = arith.select %39, %43, %arg13 : tensor<8xi1>, tensor<8xi64> + scf.yield %40, %44 : tensor<8xf32>, tensor<8xi64> + } + %9 = arith.index_cast %4 : i32 to index + %10 = arith.index_cast %arg9 : i32 to index + %11 = arith.addi %9, %10 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%11], sizes: [8], strides: [1] : memref to memref<8xi64, strided<[1], offset: ?>> + %12 = arith.addi %9, %c8 : index + %13 = arith.maxsi %9, %c1 : index + %14 = arith.minsi %12, %13 : index + %15 = arith.subi %14, %9 : index + %extracted_slice = tensor.extract_slice %8#1[0] [%15] [1] : tensor<8xi64> to tensor + %subview = memref.subview %reinterpret_cast[0] [%15] [1] : memref<8xi64, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/argmin/triton-ascend/argmin_kernel.ttir b/python/test/ops/argmin/triton-ascend/argmin_kernel.ttir new file mode 100644 index 000000000..2ce563346 --- /dev/null +++ b/python/test/ops/argmin/triton-ascend/argmin_kernel.ttir @@ -0,0 +1,131 @@ +#loc = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":37:0) +#loc1 = loc(unknown) +#loc17 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":64:22) +#loc41 = loc(callsite(#loc1 at #loc17)) +module { + tt.func public @argmin_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":37:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":37:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":37:0)) attributes {noinline = false} { + %cst = arith.constant dense<0x7F800000> : tensor<8x16xf32> loc(#loc1) + %cst_0 = arith.constant dense<0x7F800000> : tensor<8xf32> loc(#loc1) + %c16_i32 = arith.constant 16 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<8xi32> loc(#loc1) + %cst_2 = arith.constant dense<1> : tensor<8x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<8xi64> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.get_program_id y : i32 loc(#loc3) + %2 = arith.muli %0, %c8_i32 : i32 loc(#loc4) + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc5) + %4 = tt.splat %2 : i32 -> tensor<8xi32> loc(#loc6) + %5 = arith.addi %4, %3 : tensor<8xi32> loc(#loc6) + %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc7) + %7 = tt.expand_dims %5 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc8) + %8 = tt.splat %arg2 : i32 -> tensor<8x1xi32> loc(#loc9) + %9 = arith.muli %7, %8 : tensor<8x1xi32> loc(#loc9) + %10 = tt.broadcast %9 : tensor<8x1xi32> -> 
tensor<8x16xi32> loc(#loc10) + %11 = tt.splat %1 : i32 -> tensor<8x16xi32> loc(#loc11) + %12 = arith.cmpi slt, %7, %cst_2 : tensor<8x1xi32> loc(#loc12) + %13 = tt.splat %arg2 : i32 -> tensor<1x16xi32> loc(#loc13) + %14 = tt.broadcast %12 : tensor<8x1xi1> -> tensor<8x16xi1> loc(#loc14) + %15 = tt.splat %arg0 : !tt.ptr -> tensor<8x16x!tt.ptr> loc(#loc15) + %16 = tt.expand_dims %6 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc40) + %17 = tt.broadcast %16 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc40) + %18:2 = scf.for %arg3 = %c0_i32 to %arg2 step %c16_i32 iter_args(%arg4 = %cst_0, %arg5 = %cst_3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %24 = tt.splat %arg3 : i32 -> tensor<16xi32> loc(#loc19) + %25 = arith.addi %24, %6 : tensor<16xi32> loc(#loc19) + %26 = tt.expand_dims %25 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc20) + %27 = tt.broadcast %26 : tensor<1x16xi32> -> tensor<8x16xi32> loc(#loc10) + %28 = arith.addi %10, %27 : tensor<8x16xi32> loc(#loc10) + %29 = arith.addi %28, %11 : tensor<8x16xi32> loc(#loc11) + %30 = arith.cmpi slt, %26, %13 : tensor<1x16xi32> loc(#loc13) + %31 = tt.broadcast %30 : tensor<1x16xi1> -> tensor<8x16xi1> loc(#loc14) + %32 = arith.andi %14, %31 : tensor<8x16xi1> loc(#loc14) + %33 = tt.addptr %15, %29 : tensor<8x16x!tt.ptr>, tensor<8x16xi32> loc(#loc15) + %34 = tt.load %33, %32, %cst : tensor<8x16x!tt.ptr> loc(#loc21) + %35:2 = "tt.reduce"(%34, %17) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc17)), %arg7: i32 loc(callsite(#loc1 at #loc17)), %arg8: f32 loc(callsite(#loc1 at #loc17)), %arg9: i32 loc(callsite(#loc1 at #loc17))): + %42 = arith.cmpf oeq, %arg6, %arg8 : f32 loc(#loc56) + %43 = arith.cmpi slt, %arg7, %arg9 : i32 loc(#loc57) + %44 = arith.andi %42, %43 : i1 loc(#loc58) + %45 = arith.cmpf olt, %arg6, %arg8 : f32 loc(#loc59) + %46 = arith.ori %45, %44 : i1 loc(#loc60) + %47 = arith.select %46, %arg6, %arg8 : f32 loc(#loc61) + %48 = arith.select %46, %arg7, %arg9 : i32 loc(#loc62) + tt.reduce.return %47, %48 : f32, i32 loc(#loc40) + }) : (tensor<8x16xf32>, tensor<8x16xi32>) -> (tensor<8xf32>, tensor<8xi32>) loc(#loc40) + %36 = arith.cmpf olt, %35#0, %arg4 : tensor<8xf32> loc(#loc30) + %37 = arith.select %36, %35#0, %arg4 : tensor<8xi1>, tensor<8xf32> loc(#loc31) + %38 = tt.splat %arg3 : i32 -> tensor<8xi32> loc(#loc32) + %39 = arith.addi %38, %35#1 : tensor<8xi32> loc(#loc32) + %40 = arith.extsi %39 : tensor<8xi32> to tensor<8xi64> loc(#loc33) + %41 = arith.select %36, %40, %arg5 : tensor<8xi1>, tensor<8xi64> loc(#loc33) + scf.yield %37, %41 : tensor<8xf32>, tensor<8xi64> loc(#loc34) + } loc(#loc18) + %19 = tt.splat %1 : i32 -> tensor<8xi32> loc(#loc35) + %20 = arith.addi %5, %19 : tensor<8xi32> loc(#loc35) + %21 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> loc(#loc36) + %22 = tt.addptr %21, %20 : tensor<8x!tt.ptr>, tensor<8xi32> loc(#loc36) + %23 = arith.cmpi slt, %5, %cst_1 : tensor<8xi32> loc(#loc37) + tt.store %22, %18#1, %23 : tensor<8x!tt.ptr> loc(#loc38) + tt.return loc(#loc39) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":47:26) +#loc3 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":48:26) +#loc4 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":49:23) +#loc5 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":49:46) +#loc6 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":49:33) +#loc7 = 
loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":57:42) +#loc8 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":58:26) +#loc9 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":58:37) +#loc10 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":58:45) +#loc11 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":58:69) +#loc12 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":59:35) +#loc13 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":59:61) +#loc14 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":59:41) +#loc15 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":60:25) +#loc16 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":233:58) +#loc18 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":56:31) +#loc19 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":57:29) +#loc20 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":58:54) +#loc21 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":61:27) +#loc22 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:24) +#loc23 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":212:59) +#loc24 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:44) +#loc25 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:35) +#loc26 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:18) +#loc27 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:28) +#loc28 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":205:39) +#loc29 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":206:39) +#loc30 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":68:29) +#loc31 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":69:49) +#loc32 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":70:51) +#loc33 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":70:65) +#loc34 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":70:8) +#loc35 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":72:34) +#loc36 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":73:33) +#loc37 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":74:23) +#loc38 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":75:29) +#loc39 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/argmin/argmin.py":75:4) +#loc40 = loc(callsite(#loc16 at #loc17)) +#loc42 = loc(callsite(#loc22 at #loc23)) +#loc43 = loc(callsite(#loc24 at #loc23)) +#loc44 = loc(callsite(#loc25 at #loc23)) +#loc45 = loc(callsite(#loc26 at #loc23)) +#loc46 = loc(callsite(#loc27 at #loc23)) +#loc47 = loc(callsite(#loc28 at #loc23)) +#loc48 = loc(callsite(#loc29 at #loc23)) +#loc49 = loc(callsite(#loc42 at #loc16)) +#loc50 = loc(callsite(#loc43 at #loc16)) +#loc51 = loc(callsite(#loc44 at #loc16)) +#loc52 = loc(callsite(#loc45 at #loc16)) +#loc53 = loc(callsite(#loc46 at 
#loc16)) +#loc54 = loc(callsite(#loc47 at #loc16)) +#loc55 = loc(callsite(#loc48 at #loc16)) +#loc56 = loc(callsite(#loc49 at #loc17)) +#loc57 = loc(callsite(#loc50 at #loc17)) +#loc58 = loc(callsite(#loc51 at #loc17)) +#loc59 = loc(callsite(#loc52 at #loc17)) +#loc60 = loc(callsite(#loc53 at #loc17)) +#loc61 = loc(callsite(#loc54 at #loc17)) +#loc62 = loc(callsite(#loc55 at #loc17)) diff --git a/python/test/ops/bmm/bmm.py b/python/test/ops/bmm/bmm.py new file mode 100644 index 000000000..10d06ab45 --- /dev/null +++ b/python/test/ops/bmm/bmm.py @@ -0,0 +1,177 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def bmm_kernel( + A, + B, + O, + M, + N, + K, + TILE_M: tl.constexpr = 64, + TILE_N: tl.constexpr = 64, + TILE_K: tl.constexpr = 64, + GROUP_M: tl.constexpr = 1, + DIVISIBLE_M: tl.constexpr = True, + DIVISIBLE_N: tl.constexpr = True, + DIVISIBLE_K: tl.constexpr = True, +): + # batch offsets + pid_b = tl.program_id(2) + A += pid_b * M * K + B += pid_b * K * N + O += pid_b * M * N + + pidx = tl.program_id(0) + pidy = tl.program_id(1) + + if GROUP_M == 1: + pid_m, pid_n = pidx, pidy + else: + # reorder CTAs + gridx = tl.num_programs(0) + gridy = tl.num_programs(1) + pid = pidx + pidy * gridx + + num_CTA_per_group = gridy * GROUP_M + + group_id = pid // num_CTA_per_group + inner_group_id = pid % num_CTA_per_group + GROUP_SIZE = tl.where( + (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M + ) + pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE + pid_n = inner_group_id // GROUP_SIZE + + offs_m = pid_m * TILE_M + tl.arange(0, TILE_M) + offs_n = pid_n * TILE_N + tl.arange(0, TILE_N) + offs_k = tl.arange(0, TILE_K) + + if not DIVISIBLE_M: + mask_m = offs_m < M + if not DIVISIBLE_N: + mask_n = offs_n < N + + a_ptrs = A + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = B + offs_k[:, None] * N + offs_n[None, :] + o_ptrs = O + offs_m[:, None] * N + offs_n[None, :] + + num_iters = tl.cdiv(K, TILE_K) + o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for _ in range(num_iters): + if DIVISIBLE_K: + if DIVISIBLE_M: + mask_a = None + else: + mask_a = mask_m[:, None] + if DIVISIBLE_N: + mask_b = None + else: + mask_b = mask_n[None, :] + else: + mask_k = offs_k < K + if DIVISIBLE_M: + mask_a = mask_k[None, :] + else: + mask_a = mask_m[:, None] & mask_k[None, :] + if DIVISIBLE_N: + mask_b = mask_k[:, None] + else: + mask_b = mask_k[:, None] & mask_n[None, :] + + a = tl.load(a_ptrs, mask_a) + b = tl.load(b_ptrs, mask_b) + + offs_k += TILE_K + a_ptrs += TILE_K + b_ptrs += TILE_K * N + + o += tl.dot(a, b, allow_tf32=False) + + if DIVISIBLE_M and DIVISIBLE_N: + mask_c = None + elif DIVISIBLE_M and not DIVISIBLE_N: + mask_c = mask_n[None, :] + elif not DIVISIBLE_M and DIVISIBLE_N: + mask_c = mask_m[:, None] + else: + mask_c = mask_m[:, None] & mask_n[None, :] + tl.store(o_ptrs, o, mask_c) + + +def bmm(A, B): + batch, M, K = A.shape + _, _, N = B.shape + A = A.contiguous() + B = B.contiguous() + out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device) + + grid_fn = lambda meta: ( + triton.cdiv(meta["M"], meta["TILE_M"]), + triton.cdiv(meta["N"], meta["TILE_N"]), + batch, + ) + with 
torch_device_fn.device(A.device): + bmm_kernel[grid_fn](A, B, out, M, N, K) + return out + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + M = 4 + N = 2 + K = 8 + dtype = torch.float32 + batch = 4 + + # inp + mat1 = torch.randn((batch, M, K), dtype=dtype, device=device) + mat2 = torch.randn((batch, K, N), dtype=dtype, device=device) + ref_mat1 = mat1.cpu() + ref_mat2 = mat2.cpu() + + # op + ref_out = torch.bmm(ref_mat1, ref_mat2) + res_out = bmm(mat1, mat2) + check("value", ref_out, res_out, reduce_dim=K) diff --git a/python/test/ops/bmm/bmm_ascend.py b/python/test/ops/bmm/bmm_ascend.py new file mode 100644 index 000000000..70b798979 --- /dev/null +++ b/python/test/ops/bmm/bmm_ascend.py @@ -0,0 +1,147 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def bmm_kernel( + A, + B, + O, + M, + N, + K, + TILE_M: tl.constexpr = 64, + TILE_N: tl.constexpr = 64, + TILE_K: tl.constexpr = 64, + GROUP_M: tl.constexpr = 1, + DIVISIBLE_M: tl.constexpr = True, + DIVISIBLE_N: tl.constexpr = True, + DIVISIBLE_K: tl.constexpr = True, +): + # batch offsets + pid_b = tl.program_id(2) + A += pid_b * M * K + B += pid_b * K * N + O += pid_b * M * N + + pidx = tl.program_id(0) + pidy = tl.program_id(1) + if GROUP_M == 1: + pid_m, pid_n = pidx, pidy + else: + # reorder CTAs + gridx = tl.num_programs(0) + gridy = tl.num_programs(1) + pid = pidx + pidy * gridx + + num_CTA_per_group = gridy * GROUP_M + + group_id = pid // num_CTA_per_group + inner_group_id = pid % num_CTA_per_group + GROUP_SIZE = tl.where( + (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M + ) + pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE + pid_n = inner_group_id // GROUP_SIZE + + offs_m = pid_m * TILE_M + tl.arange(0, TILE_M) + offs_n = pid_n * TILE_N + tl.arange(0, TILE_N) + offs_k = tl.arange(0, TILE_K) + + a_ptrs = A + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = B + offs_k[:, None] * N + offs_n[None, :] + o_ptrs = O + offs_m[:, None] * N + offs_n[None, :] + + num_iters = tl.cdiv(K, TILE_K) + o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for i in range(num_iters): + mask_a = offs_k[None, :] < K - i * TILE_K + mask_b = offs_k[:, None] < K - i * TILE_K + a = tl.load(a_ptrs, mask=mask_a) + b = tl.load(b_ptrs, mask=mask_b) + + a_ptrs += TILE_K + b_ptrs += TILE_K * N + + o += tl.dot(a, b, allow_tf32=False) + + mask_m = (pid_m * TILE_M + tl.arange(0, TILE_M)) < M + mask_n = (pid_n * TILE_N + tl.arange(0, TILE_N)) < N + mask_c = mask_m[:, 
None] & mask_n[None, :] + tl.store(o_ptrs, o, mask_c) + + +def bmm(A, B): + batch, M, K = A.shape + _, _, N = B.shape + A = A.contiguous() + B = B.contiguous() + out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device) + + grid_fn = lambda meta: ( + triton.cdiv(meta["M"], meta["TILE_M"]), + triton.cdiv(meta["N"], meta["TILE_N"]), + batch, + ) + with torch_device_fn.device(A.device): + bmm_kernel[grid_fn](A, B, out, M, N, K) + return out + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + M = 4 + N = 2 + K = 8 + dtype = torch.float32 + batch = 4 + + # inp + mat1 = torch.randn((batch, M, K), dtype=dtype, device=device) + mat2 = torch.randn((batch, K, N), dtype=dtype, device=device) + ref_mat1 = mat1.cpu() + ref_mat2 = mat2.cpu() + + # op + ref_out = torch.bmm(ref_mat1, ref_mat2) + res_out = bmm(mat1, mat2) + check("value", ref_out, res_out, reduce_dim=K) diff --git a/python/test/ops/bmm/triton-ascend/bmm_kernel.ttadapter b/python/test/ops/bmm/triton-ascend/bmm_kernel.ttadapter new file mode 100644 index 000000000..6d62b2ce2 --- /dev/null +++ b/python/test/ops/bmm/triton-ascend/bmm_kernel.ttadapter @@ -0,0 +1,89 @@ +module { + func.func @bmm_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "mix"} { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c64_i32 = arith.constant 64 : i32 + %c63_i32 = arith.constant 63 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<64x64xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32> + %2 = arith.muli %arg13, %arg5 : i32 + %3 = arith.muli %2, %arg7 : i32 + %4 = arith.muli %arg13, %arg7 : i32 + %5 = arith.muli %4, %arg6 : i32 + %6 = arith.muli %2, %arg6 : i32 + %7 = arith.index_cast %6 : i32 to index + %8 = arith.muli %arg11, %c64_i32 : i32 + %9 = arith.muli %arg12, %c64_i32 : i32 + %10 = arith.index_cast %8 : i32 to index + %11 = arith.index_cast %arg6 : i32 to index + %12 = arith.muli %10, %11 : index + %13 = arith.addi %7, %12 : index + %14 = arith.index_cast %9 : i32 to index + %15 = arith.addi %13, %14 : index + %reinterpret_cast = memref.reinterpret_cast %arg4 to offset: [%15], sizes: [64, 64], strides: [%11, 1] : memref to memref<64x64xf32, strided<[?, 1], offset: ?>> + %16 = arith.addi %arg7, %c63_i32 : i32 + %17 = arith.divsi %16, %c64_i32 : i32 + %18 
= arith.muli %arg6, %c64_i32 : i32 + %19 = arith.index_cast %3 : i32 to index + %20 = arith.index_cast %arg7 : i32 to index + %21 = arith.muli %10, %20 : index + %22 = arith.addi %19, %21 : index + %23 = arith.index_cast %5 : i32 to index + %24:3 = scf.for %arg14 = %c0_i32 to %17 step %c1_i32 iter_args(%arg15 = %1, %arg16 = %22, %arg17 = %23) -> (tensor<64x64xf32>, index, index) : i32 { + %36 = arith.addi %arg17, %14 : index + %reinterpret_cast_0 = memref.reinterpret_cast %arg3 to offset: [%36], sizes: [64, 64], strides: [%11, %c1] : memref to memref<64x64xf32, strided<[?, ?], offset: ?>> + %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [%arg16], sizes: [64, 64], strides: [%20, %c1] : memref to memref<64x64xf32, strided<[?, ?], offset: ?>> + %37 = arith.muli %arg14, %c64_i32 : i32 + %38 = arith.subi %arg7, %37 : i32 + %alloc = memref.alloc() : memref<64x64xf32> + %39 = arith.index_cast %38 : i32 to index + %40 = arith.maxsi %39, %c0 : index + %41 = arith.minsi %40, %c64 : index + %42 = arith.cmpi slt, %41, %c64 : index + scf.if %42 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<64x64xf32>) + } + %subview_2 = memref.subview %reinterpret_cast_1[0, 0] [64, %41] [1, 1] : memref<64x64xf32, strided<[?, ?], offset: ?>> to memref<64x?xf32, strided<[?, ?], offset: ?>> + %subview_3 = memref.subview %alloc[0, 0] [64, %41] [1, 1] : memref<64x64xf32> to memref<64x?xf32, strided<[64, 1]>> + memref.copy %subview_2, %subview_3 : memref<64x?xf32, strided<[?, ?], offset: ?>> to memref<64x?xf32, strided<[64, 1]>> + annotation.mark %subview_3 {MayImplicitTransposeWithLastAxis} : memref<64x?xf32, strided<[64, 1]>> + %43 = bufferization.to_tensor %alloc restrict writable : memref<64x64xf32> + annotation.mark %43 {MayImplicitTransposeWithLastAxis} : tensor<64x64xf32> + %alloc_4 = memref.alloc() : memref<64x64xf32> + scf.if %42 { + linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<64x64xf32>) + } + %subview_5 = memref.subview %reinterpret_cast_0[0, 0] [%41, 64] [1, 1] : memref<64x64xf32, strided<[?, ?], offset: ?>> to memref> + %subview_6 = memref.subview %alloc_4[0, 0] [%41, 64] [1, 1] : memref<64x64xf32> to memref> + memref.copy %subview_5, %subview_6 : memref> to memref> + annotation.mark %subview_6 {MayImplicitTransposeWithLastAxis} : memref> + %44 = bufferization.to_tensor %alloc_4 restrict writable : memref<64x64xf32> + annotation.mark %44 {MayImplicitTransposeWithLastAxis} : tensor<64x64xf32> + %45 = linalg.matmul {input_precison = "ieee"} ins(%43, %44 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%arg15 : tensor<64x64xf32>) -> tensor<64x64xf32> + %46 = arith.addi %arg16, %c64 : index + %47 = arith.index_cast %18 : i32 to index + %48 = arith.addi %arg17, %47 : index + scf.yield %45, %46, %48 : tensor<64x64xf32>, index, index + } + %25 = arith.addi %10, %c64 : index + %26 = arith.index_cast %arg5 : i32 to index + %27 = arith.maxsi %10, %26 : index + %28 = arith.minsi %25, %27 : index + %29 = arith.subi %28, %10 : index + %30 = arith.addi %14, %c64 : index + %31 = arith.maxsi %14, %11 : index + %32 = arith.minsi %30, %31 : index + %33 = arith.subi %32, %14 : index + %34 = arith.minsi %29, %c64 : index + %35 = arith.minsi %33, %c64 : index + %extracted_slice = tensor.extract_slice %24#0[0, 0] [%34, %35] [1, 1] : tensor<64x64xf32> to tensor + %subview = memref.subview %reinterpret_cast[0, 0] [%34, %35] [1, 1] : memref<64x64xf32, strided<[?, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + return + } +} 
+ diff --git a/python/test/ops/bmm/triton-ascend/bmm_kernel.ttir b/python/test/ops/bmm/triton-ascend/bmm_kernel.ttir new file mode 100644 index 000000000..2b6f075ed --- /dev/null +++ b/python/test/ops/bmm/triton-ascend/bmm_kernel.ttir @@ -0,0 +1,137 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0) +module { + tt.func public @bmm_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0), %arg3: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0), %arg4: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0), %arg5: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":17:0)) attributes {noinline = false} { + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32> loc(#loc1) + %cst_0 = arith.constant dense<64> : tensor<64x64xi32> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %0 = tt.get_program_id z : i32 loc(#loc2) + %1 = arith.muli %0, %arg3 : i32 loc(#loc3) + %2 = arith.muli %1, %arg5 : i32 loc(#loc4) + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 loc(#loc5) + %4 = arith.muli %0, %arg5 : i32 loc(#loc6) + %5 = arith.muli %4, %arg4 : i32 loc(#loc7) + %6 = tt.addptr %arg1, %5 : !tt.ptr, i32 loc(#loc8) + %7 = arith.muli %1, %arg4 : i32 loc(#loc9) + %8 = tt.addptr %arg2, %7 : !tt.ptr, i32 loc(#loc10) + %9 = tt.get_program_id x : i32 loc(#loc11) + %10 = tt.get_program_id y : i32 loc(#loc12) + %11 = arith.muli %9, %c64_i32 : i32 loc(#loc13) + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc14) + %13 = tt.splat %11 : i32 -> tensor<64xi32> loc(#loc15) + %14 = arith.addi %13, %12 : tensor<64xi32> loc(#loc15) + %15 = arith.muli %10, %c64_i32 : i32 loc(#loc16) + %16 = tt.splat %15 : i32 -> tensor<64xi32> loc(#loc17) + %17 = arith.addi %16, %12 : tensor<64xi32> loc(#loc17) + %18 = tt.expand_dims %14 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc18) + %19 = tt.splat %arg5 : i32 -> tensor<64x1xi32> loc(#loc19) + %20 = arith.muli %18, %19 : tensor<64x1xi32> loc(#loc19) + %21 = tt.splat %3 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc20) + %22 = tt.addptr %21, %20 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc20) + %23 = tt.expand_dims %12 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc21) + %24 = tt.broadcast %22 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> loc(#loc22) + %25 = tt.broadcast %23 : tensor<1x64xi32> -> tensor<64x64xi32> loc(#loc22) + %26 = tt.addptr %24, %25 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> loc(#loc22) + %27 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc23) + %28 = tt.splat %arg4 : i32 -> tensor<64x1xi32> loc(#loc24) + %29 = arith.muli %27, %28 : tensor<64x1xi32> loc(#loc24) + %30 = tt.splat %6 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc25) + %31 = tt.addptr %30, %29 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc25) + %32 = tt.expand_dims %17 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc26) + %33 = tt.broadcast %31 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> loc(#loc27) + %34 = tt.broadcast %32 : 
tensor<1x64xi32> -> tensor<64x64xi32> loc(#loc27) + %35 = tt.addptr %33, %34 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> loc(#loc27) + %36 = arith.muli %18, %28 : tensor<64x1xi32> loc(#loc28) + %37 = tt.splat %8 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc29) + %38 = tt.addptr %37, %36 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc29) + %39 = tt.broadcast %38 : tensor<64x1x!tt.ptr> -> tensor<64x64x!tt.ptr> loc(#loc30) + %40 = tt.addptr %39, %34 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> loc(#loc30) + %41 = arith.addi %arg5, %c63_i32 : i32 loc(#loc53) + %42 = arith.divsi %41, %c64_i32 : i32 loc(#loc54) + %43 = arith.muli %arg4, %c64_i32 : i32 loc(#loc34) + %44 = tt.splat %43 : i32 -> tensor<64x64xi32> loc(#loc35) + %45:3 = scf.for %arg6 = %c0_i32 to %42 step %c1_i32 iter_args(%arg7 = %26, %arg8 = %35, %arg9 = %cst) -> (tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64x64xf32>) : i32 { + %55 = arith.muli %arg6, %c64_i32 : i32 loc(#loc37) + %56 = arith.subi %arg5, %55 : i32 loc(#loc38) + %57 = tt.splat %56 : i32 -> tensor<1x64xi32> loc(#loc39) + %58 = arith.cmpi slt, %23, %57 : tensor<1x64xi32> loc(#loc39) + %59 = tt.splat %56 : i32 -> tensor<64x1xi32> loc(#loc40) + %60 = arith.cmpi slt, %27, %59 : tensor<64x1xi32> loc(#loc40) + %61 = tt.broadcast %58 : tensor<1x64xi1> -> tensor<64x64xi1> loc(#loc41) + %62 = tt.load %arg7, %61, %cst : tensor<64x64x!tt.ptr> loc(#loc41) + %63 = tt.broadcast %60 : tensor<64x1xi1> -> tensor<64x64xi1> loc(#loc42) + %64 = tt.load %arg8, %63, %cst : tensor<64x64x!tt.ptr> loc(#loc42) + %65 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> loc(#loc43) + %66 = tt.addptr %arg8, %44 : tensor<64x64x!tt.ptr>, tensor<64x64xi32> loc(#loc35) + %67 = tt.dot %62, %64, %arg9 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> loc(#loc44) + scf.yield %65, %66, %67 : tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>, tensor<64x64xf32> loc(#loc45) + } loc(#loc36) + %46 = tt.splat %arg3 : i32 -> tensor<64xi32> loc(#loc46) + %47 = arith.cmpi slt, %14, %46 : tensor<64xi32> loc(#loc46) + %48 = tt.splat %arg4 : i32 -> tensor<64xi32> loc(#loc47) + %49 = arith.cmpi slt, %17, %48 : tensor<64xi32> loc(#loc47) + %50 = tt.expand_dims %47 {axis = 1 : i32} : tensor<64xi1> -> tensor<64x1xi1> loc(#loc48) + %51 = tt.expand_dims %49 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1> loc(#loc49) + %52 = tt.broadcast %50 : tensor<64x1xi1> -> tensor<64x64xi1> loc(#loc50) + %53 = tt.broadcast %51 : tensor<1x64xi1> -> tensor<64x64xi1> loc(#loc50) + %54 = arith.andi %52, %53 : tensor<64x64xi1> loc(#loc50) + tt.store %40, %45#2, %54 : tensor<64x64x!tt.ptr> loc(#loc51) + tt.return loc(#loc52) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":33:26) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":34:17) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":34:21) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":34:9) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":35:17) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":35:21) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":35:9) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":36:21) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":36:9) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":38:25) 
+#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":39:25) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":58:21) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":58:43) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":58:30) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":59:21) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":59:30) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":62:24) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":62:35) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":62:17) +#loc21 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":62:46) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":62:39) +#loc23 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":63:24) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":63:35) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":63:17) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":63:46) +#loc27 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":63:39) +#loc28 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":64:35) +#loc29 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":64:17) +#loc30 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":64:39) +#loc31 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":40:22) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":66:27) +#loc33 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":40:28) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":75:27) +#loc35 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":75:18) +#loc36 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":68:19) +#loc37 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":69:43) +#loc38 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":69:39) +#loc39 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":69:35) +#loc40 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":70:35) +#loc41 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":71:20) +#loc42 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":72:20) +#loc43 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":74:18) +#loc44 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":77:23) +#loc45 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":77:8) +#loc46 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":79:55) +#loc47 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":80:55) +#loc48 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":81:20) +#loc49 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":81:38) +#loc50 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":81:31) +#loc51 = loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":82:24) +#loc52 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/bmm/bmm_ascend.py":82:4) +#loc53 = loc(callsite(#loc31 at #loc32)) +#loc54 = loc(callsite(#loc33 at #loc32)) diff --git a/python/test/ops/cumsum/cumsum.py b/python/test/ops/cumsum/cumsum.py new file mode 100644 index 000000000..3586c556b --- /dev/null +++ b/python/test/ops/cumsum/cumsum.py @@ -0,0 +1,191 @@ +import math + +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit(do_not_specialize=["part_num"]) +def scan_part_sum_abc_kernel( + inp, + out, + partial_sum, + B, + C, + part_num, + BLOCK_SIZE: tl.constexpr, +): + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + pid_c = tl.program_id(2) + + a_idx = pid_a + b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + c_idx = pid_c + + offset = a_idx * B * C + b_idx * C + c_idx + base_part_offset = a_idx * part_num * C + c_idx + part_offset = base_part_offset + pid_b * C + + mask = b_idx < B + inp_ptrs = inp + offset + inp_vals = tl.load(inp_ptrs, mask=mask) + if ( + tl.constexpr(inp_vals.dtype.is_int64()) + or tl.constexpr(inp_vals.dtype.is_uint64()) + ) or tl.constexpr(inp_vals.dtype.is_fp64()): + inp_vals = inp_vals + elif tl.constexpr(inp_vals.dtype.is_int()): + inp_vals = inp_vals.to(tl.int32) + else: + inp_vals = inp_vals.to(tl.float32) + result = tl.cumsum(inp_vals, axis=0) + + part_sum_via_sum = tl.sum(inp_vals) + + out_ptrs = out + offset + tl.store(out_ptrs, result, mask=mask) + + partial_sum_ptrs = partial_sum + part_offset + tl.store(partial_sum_ptrs, part_sum_via_sum) + + +@triton.jit(do_not_specialize=["part_num"]) +def add_base_sum_abc_kernel( + out, + partial_sum, + B, + C, + part_num, + BLOCK_SIZE: tl.constexpr, +): + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + pid_c = tl.program_id(2) + + a_idx = pid_a + b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + c_idx = pid_c + + base_offset = a_idx * B * C + c_idx + offset = base_offset + b_idx * C + base_part_offset = a_idx * part_num * C + c_idx + last_part_offset = base_part_offset + (pid_b - 1) * C + + mask = b_idx < B + out_ptrs = out + offset + out_vals = tl.load(out_ptrs, mask=mask) + + if pid_b > 0: + partial_sum_ptrs = partial_sum + last_part_offset + last_part_sum_via_sum = tl.load(partial_sum_ptrs) + + final_vals = out_vals + last_part_sum_via_sum + tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask) + + +def scan_then_fan(inp, out, A, B, C, dtype): + # TODO(all): tune on target board + BLOCK_SIZE = 1024 + if B <= 1024 * 4: + BLOCK_SIZE = triton.next_power_of_2(B) + part_num = math.ceil(B / BLOCK_SIZE) + partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device) + + grid = (A, part_num, C) + with torch_device_fn.device(inp.device): + scan_part_sum_abc_kernel[grid]( + inp, out, partial_sum, B, C, part_num, BLOCK_SIZE + ) + + if part_num >= 2: + scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype) + with torch_device_fn.device(inp.device): + add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE) + + +def cumsum_wrapper(inp, dim=1, dtype=None, out=None): + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + shape = inp.shape + dim = dim % inp.ndim + M = 1 + N = 
shape[dim] + for i in range(dim): + M *= shape[i] + inp = inp.contiguous() + K = inp.numel() // M // N + + if dtype is None: + dtype = inp.dtype + if dtype is torch.bool: + dtype = torch.int64 + if out is None: + out = torch.empty_like(inp, dtype=dtype) + + compute_dtype = out.dtype + if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16: + compute_dtype = torch.float32 + + if M == 1 and K == 1: + pass + else: + scan_then_fan(inp, out, M, N, K, compute_dtype) + return out + + +def cumsum(inp, dim=1, *, dtype=None): + return cumsum_wrapper(inp, dim, dtype) + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (2, 4096) + dtype = torch.float32 + dim = 1 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.cumsum(ref_inp, dim=dim) + res_out = cumsum(inp, dim=dim) + check("value", ref_out, res_out, reduce_dim=shape[dim]) diff --git a/python/test/ops/cumsum/triton-ascend/scan_part_sum_abc_kernel.ttadapter b/python/test/ops/cumsum/triton-ascend/scan_part_sum_abc_kernel.ttadapter new file mode 100644 index 000000000..46e509e4a --- /dev/null +++ b/python/test/ops/cumsum/triton-ascend/scan_part_sum_abc_kernel.ttadapter @@ -0,0 +1,55 @@ +module { + func.func private @triton_cumsum_0(tensor<4096xf32>, i32, i1) -> tensor<4096xf32> + func.func @scan_part_sum_abc_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c0_i32 = arith.constant 0 : i32 + %false = arith.constant false + %cst = arith.constant 0.000000e+00 : f32 + %c4096 = arith.constant 4096 : index + %c4096_i32 = arith.constant 4096 : i32 + %0 = arith.muli %arg11, %c4096_i32 : i32 + %1 = arith.muli %arg10, %arg5 : i32 + %2 = arith.muli %arg10, %arg6 : i32 + %3 = arith.addi %2, %arg12 : i32 + %4 = arith.addi %3, %arg11 : i32 + %5 = arith.index_cast %1 : i32 to index + %6 = arith.index_cast %0 : i32 to index + %7 = arith.addi %5, %6 : index + %8 = arith.index_cast %arg12 : i32 to index + %9 = arith.addi %7, %8 : index + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%9], sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1], offset: ?>> + %alloc = memref.alloc() : memref<4096xf32> + %10 = arith.addi %6, %c4096 : index + %11 = arith.index_cast %arg5 : i32 to index + %12 = arith.maxsi %6, %11 : index + %13 = arith.minsi %10, %12 : index + %14 = arith.subi %13, %6 : index + %15 = arith.cmpi slt, %14, %c4096 : index + 
scf.if %15 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<4096xf32>) + } + %subview = memref.subview %reinterpret_cast[0] [%14] [1] : memref<4096xf32, strided<[1], offset: ?>> to memref> + %subview_0 = memref.subview %alloc[0] [%14] [1] : memref<4096xf32> to memref> + memref.copy %subview, %subview_0 : memref> to memref> + %16 = bufferization.to_tensor %alloc restrict writable : memref<4096xf32> + %17 = call @triton_cumsum_0(%16, %c0_i32, %false) : (tensor<4096xf32>, i32, i1) -> tensor<4096xf32> + %18 = bufferization.alloc_tensor() : tensor + %19 = linalg.fill ins(%cst : f32) outs(%18 : tensor) -> tensor + %reduced = linalg.reduce ins(%16 : tensor<4096xf32>) outs(%19 : tensor) dimensions = [0] + (%in: f32, %init: f32) { + %23 = arith.addf %in, %init : f32 + linalg.yield %23 : f32 + } + %extracted = tensor.extract %reduced[] : tensor + %reinterpret_cast_1 = memref.reinterpret_cast %arg3 to offset: [%9], sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1], offset: ?>> + %extracted_slice = tensor.extract_slice %17[0] [%14] [1] : tensor<4096xf32> to tensor + %subview_2 = memref.subview %reinterpret_cast_1[0] [%14] [1] : memref<4096xf32, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_2 : (tensor, memref>) -> () + %20 = arith.index_cast %4 : i32 to index + %21 = tensor.empty() : tensor<1xf32> + %22 = linalg.fill ins(%extracted : f32) outs(%21 : tensor<1xf32>) -> tensor<1xf32> + %reinterpret_cast_3 = memref.reinterpret_cast %arg4 to offset: [%20], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1], offset: ?>> + bufferization.materialize_in_destination %22 in writable %reinterpret_cast_3 : (tensor<1xf32>, memref<1xf32, strided<[1], offset: ?>>) -> () + return + } +} + diff --git a/python/test/ops/cumsum/triton-ascend/scan_part_sum_abc_kernel.ttir b/python/test/ops/cumsum/triton-ascend/scan_part_sum_abc_kernel.ttir new file mode 100644 index 000000000..b8b57e787 --- /dev/null +++ b/python/test/ops/cumsum/triton-ascend/scan_part_sum_abc_kernel.ttir @@ -0,0 +1,77 @@ +#loc = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":19:0) +#loc1 = loc(unknown) +#loc18 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":52:23) +#loc21 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":54:30) +#loc28 = loc(callsite(#loc1 at #loc18)) +#loc31 = loc(callsite(#loc1 at #loc21)) +module { + tt.func public @scan_part_sum_abc_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":19:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":19:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":19:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":19:0), %arg4: i32 loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":19:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<4096xf32> loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.get_program_id y : i32 loc(#loc3) + %2 = tt.get_program_id z : i32 loc(#loc4) + %3 = arith.muli %1, %c4096_i32 : i32 loc(#loc5) + %4 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> loc(#loc6) + %5 = tt.splat %3 : i32 
-> tensor<4096xi32> loc(#loc7) + %6 = arith.addi %5, %4 : tensor<4096xi32> loc(#loc7) + %7 = arith.muli %0, %arg3 : i32 loc(#loc8) + %8 = tt.splat %7 : i32 -> tensor<4096xi32> loc(#loc9) + %9 = arith.addi %8, %6 : tensor<4096xi32> loc(#loc9) + %10 = tt.splat %2 : i32 -> tensor<4096xi32> loc(#loc10) + %11 = arith.addi %9, %10 : tensor<4096xi32> loc(#loc10) + %12 = arith.muli %0, %arg4 : i32 loc(#loc11) + %13 = arith.addi %12, %2 : i32 loc(#loc12) + %14 = arith.addi %13, %1 : i32 loc(#loc13) + %15 = tt.splat %arg3 : i32 -> tensor<4096xi32> loc(#loc14) + %16 = arith.cmpi slt, %6, %15 : tensor<4096xi32> loc(#loc14) + %17 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> loc(#loc15) + %18 = tt.addptr %17, %11 : tensor<4096x!tt.ptr>, tensor<4096xi32> loc(#loc15) + %19 = tt.load %18, %16, %cst : tensor<4096x!tt.ptr> loc(#loc16) + %20 = "tt.scan"(%19) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg5: f32 loc(callsite(#loc1 at #loc18)), %arg6: f32 loc(callsite(#loc1 at #loc18))): + %25 = arith.addf %arg5, %arg6 : f32 loc(#loc33) + tt.scan.return %25 : f32 loc(#loc27) + }) : (tensor<4096xf32>) -> tensor<4096xf32> loc(#loc27) + %21 = "tt.reduce"(%19) <{axis = 0 : i32}> ({ + ^bb0(%arg5: f32 loc(callsite(#loc1 at #loc21)), %arg6: f32 loc(callsite(#loc1 at #loc21))): + %25 = arith.addf %arg5, %arg6 : f32 loc(#loc34) + tt.reduce.return %25 : f32 loc(#loc30) + }) : (tensor<4096xf32>) -> f32 loc(#loc30) + %22 = tt.splat %arg1 : !tt.ptr -> tensor<4096x!tt.ptr> loc(#loc22) + %23 = tt.addptr %22, %11 : tensor<4096x!tt.ptr>, tensor<4096xi32> loc(#loc22) + tt.store %23, %20, %16 : tensor<4096x!tt.ptr> loc(#loc23) + %24 = tt.addptr %arg2, %14 : !tt.ptr, i32 loc(#loc24) + tt.store %24, %21 : !tt.ptr loc(#loc25) + tt.return loc(#loc26) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":28:26) +#loc3 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":29:26) +#loc4 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":30:26) +#loc5 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":33:20) +#loc6 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":33:46) +#loc7 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":33:33) +#loc8 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":36:21) +#loc9 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":36:29) +#loc10 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":36:41) +#loc11 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":37:31) +#loc12 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":37:46) +#loc13 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":38:37) +#loc14 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":40:19) +#loc15 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":41:21) +#loc16 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":42:23) +#loc17 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":299:60) +#loc19 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":256:15) +#loc20 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":267:36) +#loc22 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":56:21) +#loc23 = 
loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":57:23) +#loc24 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":59:37) +#loc25 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":60:31) +#loc26 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/cumsum/cumsum.py":60:4) +#loc27 = loc(callsite(#loc17 at #loc18)) +#loc29 = loc(callsite(#loc19 at #loc17)) +#loc30 = loc(callsite(#loc20 at #loc21)) +#loc32 = loc(callsite(#loc19 at #loc20)) +#loc33 = loc(callsite(#loc29 at #loc18)) +#loc34 = loc(callsite(#loc32 at #loc21)) diff --git a/python/test/ops/min_dim/min_dim.py b/python/test/ops/min_dim/min_dim.py new file mode 100644 index 000000000..d5badd1b8 --- /dev/null +++ b/python/test/ops/min_dim/min_dim.py @@ -0,0 +1,126 @@ +from collections import namedtuple + +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def get_dtype_max(dtype: tl.constexpr): + """get a value which is greater that all other values of that dtype""" + # extract the tl.dtype from tl.constexpr so as to use its methods + dtype_ = dtype.value + if dtype_.is_floating(): + value: tl.constexpr = float("inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2 ** (width - 1) - 1 + return value + if dtype_.is_int_unsigned(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2**width - 1 + return value + + +@triton.jit +def min_kernel( + inp, + out_value, + out_index, + M, # 1 + N, # 32 + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 256, +): + # 1. prepare offset + pid_m = tl.program_id(0) + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + dtype = inp.type.element_ty + # you just cannot create a function that return a tl.dtype in triton lang + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) + argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) + + # 2. for + for start_n in range(0, N, BLOCK_N): + n_offset = start_n + tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N + n_offset[None, :] + mask = m_offset[:, None] < M and n_offset[None, :] < N + inp_ptrs = inp + offset + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) + local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True) + update = local_min < min_values + min_values = tl.where(update, local_min, min_values) + argmin_values = tl.where(update, start_n + local_argmin, argmin_values) + + # 3. 
store + offset_index = m_offset + out_value_ptrs = out_value + offset_index + out_index_ptrs = out_index + offset_index + mask1 = m_offset < M + tl.store(out_value_ptrs, min_values, mask=mask1) + tl.store(out_index_ptrs, argmin_values, mask=mask1) + + +def min_dim(inp, dim=None, keepdim=False): + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + shape = list(inp.shape) + dim = dim % inp.ndim + N = shape[dim] + shape[dim] = 1 + M = inp.numel() // N + + out_value = torch.empty(shape, dtype=inp.dtype, device=inp.device) + out_index = torch.empty(shape, dtype=torch.int64, device=inp.device) + + if not keepdim: + out_value = torch.squeeze(out_value, dim) + out_index = torch.squeeze(out_index, dim) + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + with torch_device_fn.device(inp.device): + min_kernel[grid](inp, out_value, out_index, M, N) + Min_out = namedtuple("min", ["values", "indices"]) + out = Min_out(values=out_value, indices=out_index) + return out + + +if __name__ == "__main__": + # param + shape = (1, 32) + dim = 1 + keepdim = True + + # inp + inp = torch.randn(shape, dtype=torch.float32, device=device) + ref_inp = inp.cpu() + + # op + ref_out_value, ref_out_index = torch.min(ref_inp, dim=dim, keepdim=keepdim) + res_out_value, res_out_index = min_dim(inp, dim=dim, keepdim=keepdim) + + # check + res_out_value = res_out_value.cpu() + print( + f"The maximum difference out value between torch and triton is " + f"{torch.max(torch.abs(ref_out_value - res_out_value))}" + ) + assert torch.allclose(res_out_value, ref_out_value), (res_out_value, ref_out_value) + res_out_index = res_out_index.cpu() + print( + f"The maximum difference out index between torch and triton is " + f"{torch.max(torch.abs(ref_out_index - res_out_index))}" + ) + assert torch.allclose(res_out_index, ref_out_index), (res_out_index, ref_out_index) diff --git a/python/test/ops/min_dim/min_dim_ascend_perf.py b/python/test/ops/min_dim/min_dim_ascend_perf.py new file mode 100644 index 000000000..c5951d40a --- /dev/null +++ b/python/test/ops/min_dim/min_dim_ascend_perf.py @@ -0,0 +1,138 @@ +from collections import namedtuple +import math + +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def get_dtype_max(dtype: tl.constexpr): + """get a value which is greater than all other values of that dtype""" + # extract the tl.dtype from tl.constexpr so as to use its methods + dtype_ = dtype.value + if dtype_.is_floating(): + value: tl.constexpr = float("inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2 ** (width - 1) - 1 + return value + if dtype_.is_int_unsigned(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2**width - 1 + return value + + +@triton.jit +def min_kernel( + inp, + out_value, + out_index, + M, # 1 + N, # 32 + K, + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 256, +): + # set offset + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + dtype = inp.type.element_ty + # you just cannot create a function that returns a tl.dtype in triton lang + acc_type = tl.float32 if dtype 
is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) + argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) + for start_n in range(0, N, BLOCK_N): + n_offset = start_n + tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k + mask = m_offset[:, None] < M and n_offset[None, :] < N + inp_ptrs = inp + offset + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) + if dtype is tl.int64: + inp_vals = tl.where(mask, inp_vals, max_value) + local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True) + # if return indices is not supported, call a tl.argmax in addition + # local_argmin = tl.argmin(inp_vals, 1) + update = local_min < min_values + min_values = tl.where(update, local_min, min_values) + argmin_values = tl.where(update, start_n + local_argmin, argmin_values) + + offset_index = m_offset * K + pid_k + out_value_ptrs = out_value + offset_index + out_index_ptrs = out_index + offset_index + mask1 = m_offset < M + tl.store(out_value_ptrs, min_values, mask=mask1) + tl.store(out_index_ptrs, argmin_values, mask=mask1) + + +def min_dim(inp, dim=None, keepdim=False): + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" + shape = inp.shape + dim = dim % inp.ndim + N = shape[dim] + M = math.prod(shape[:dim]) + K = inp.numel() // M // N + + inp = inp.contiguous() + + shape_list = list(shape) + shape_list[dim] = 1 + out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device) + out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device) + + if not keepdim: + out_value = torch.squeeze(out_value, dim) + out_index = torch.squeeze(out_index, dim) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + K, + ) + with torch_device_fn.device(inp.device): + min_kernel[grid](inp, out_value, out_index, M, N, K) + Min_out = namedtuple("min", ["values", "indices"]) + out = Min_out(values=out_value, indices=out_index) + return out + + +if __name__ == "__main__": + # param + shape = (2, 16, 4) + dim = 1 + keepdim = True + + # inp + inp = torch.randn(shape, dtype=torch.float32, device=device) + ref_inp = inp.cpu() + + # op + ref_out_value, ref_out_index = torch.min(ref_inp, dim=dim, keepdim=keepdim) + res_out_value, res_out_index = min_dim(inp, dim=dim, keepdim=keepdim) + + # check + res_out_value = res_out_value.cpu() + print( + f"The maximum difference out value between torch and triton is " + f"{torch.max(torch.abs(ref_out_value - res_out_value))}" + ) + assert torch.allclose(res_out_value, ref_out_value), (res_out_value, ref_out_value) + res_out_index = res_out_index.cpu() + print( + f"The maximum difference out index between torch and triton is " + f"{torch.max(torch.abs(ref_out_index - res_out_index))}" + ) + assert torch.allclose(res_out_index, ref_out_index), (res_out_index, ref_out_index) diff --git a/python/test/ops/min_dim/triton-ascend-perf/min_kernel.ttadapter b/python/test/ops/min_dim/triton-ascend-perf/min_kernel.ttadapter new file mode 100644 index 000000000..d4dda7b76 --- /dev/null +++ b/python/test/ops/min_dim/triton-ascend-perf/min_kernel.ttadapter @@ -0,0 +1,103 @@ +#map = affine_map<(d0) -> (d0)> +module { + func.func @min_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, 
%arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c-1_i32 = arith.constant -1 : i32 + %c256 = arith.constant 256 : index + %c8 = arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0x7F800000 : f32 + %0 = tensor.empty() : tensor<8xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32> + %2 = tensor.empty() : tensor<8xi64> + %3 = linalg.fill ins(%c0_i64 : i64) outs(%2 : tensor<8xi64>) -> tensor<8xi64> + %4 = arith.muli %arg11, %c8_i32 : i32 + %5 = tensor.empty() : tensor<256xi32> + %6 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%5 : tensor<256xi32>) { + ^bb0(%out: i32): + %19 = linalg.index 0 : index + %20 = arith.index_cast %19 : index to i32 + linalg.yield %20 : i32 + } -> tensor<256xi32> + %7 = tensor.empty() : tensor<8x256xi32> + %broadcasted = linalg.broadcast ins(%6 : tensor<256xi32>) outs(%7 : tensor<8x256xi32>) dimensions = [0] + %8:2 = scf.for %arg14 = %c0_i32 to %arg6 step %c256_i32 iter_args(%arg15 = %1, %arg16 = %3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %19 = arith.index_cast %4 : i32 to index + %20 = arith.index_cast %arg6 : i32 to index + %21 = arith.muli %19, %20 : index + %22 = arith.index_cast %arg7 : i32 to index + %23 = arith.muli %21, %22 : index + %24 = arith.muli %20, %22 : index + %25 = arith.index_cast %arg14 : i32 to index + %26 = arith.muli %25, %22 : index + %27 = arith.index_cast %arg12 : i32 to index + %28 = arith.addi %23, %27 : index + %29 = arith.addi %28, %26 : index + %reinterpret_cast_3 = memref.reinterpret_cast %arg2 to offset: [%29], sizes: [8, 256], strides: [%24, %22] : memref to memref<8x256xf32, strided<[?, ?], offset: ?>> + %alloc = memref.alloc() : memref<8x256xf32> + %30 = arith.addi %19, %c8 : index + %31 = arith.index_cast %arg5 : i32 to index + %32 = arith.maxsi %19, %31 : index + %33 = arith.minsi %30, %32 : index + %34 = arith.subi %33, %19 : index + %35 = arith.addi %25, %c256 : index + %36 = arith.maxsi %25, %20 : index + %37 = arith.minsi %35, %36 : index + %38 = arith.subi %37, %25 : index + %39 = arith.minsi %34, %c8 : index + %40 = arith.minsi %38, %c256 : index + %41 = arith.cmpi slt, %39, %c8 : index + %42 = arith.cmpi slt, %40, %c256 : index + %43 = arith.ori %41, %42 : i1 + scf.if %43 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x256xf32>) + } + %subview_4 = memref.subview %reinterpret_cast_3[0, 0] [%39, %40] [1, 1] : memref<8x256xf32, strided<[?, ?], offset: ?>> to memref> + %subview_5 = memref.subview %alloc[0, 0] [%39, %40] [1, 1] : memref<8x256xf32> to memref> + memref.copy %subview_4, %subview_5 : memref> to memref> + %44 = bufferization.to_tensor %alloc restrict writable : memref<8x256xf32> + %45 = tensor.empty() : tensor<8xi32> + %46 = linalg.fill ins(%c-1_i32 : i32) outs(%45 : tensor<8xi32>) -> tensor<8xi32> + %reduced:2 = linalg.reduce ins(%44, %broadcasted : tensor<8x256xf32>, tensor<8x256xi32>) outs(%1, %46 : tensor<8xf32>, tensor<8xi32>) dimensions = [1] {reduce_mode = "min_with_index"} + (%in: f32, %in_6: i32, %init: f32, %init_7: i32) { + %53 = arith.cmpf olt, %in, %init : f32 + %54 = arith.cmpf oeq, %in, %init : f32 + %55 = arith.cmpi slt, %in_6, %init_7 : i32 + %56 = arith.andi %54, %55 : i1 + %57 = arith.ori %53, %56 : i1 + %58 = arith.select %57, %in, %init : 
f32 + %59 = arith.select %57, %in_6, %init_7 : i32 + linalg.yield %58, %59 : f32, i32 + } + %47 = arith.cmpf olt, %reduced#0, %arg15 : tensor<8xf32> + %48 = arith.select %47, %reduced#0, %arg15 : tensor<8xi1>, tensor<8xf32> + %49 = linalg.fill ins(%arg14 : i32) outs(%45 : tensor<8xi32>) -> tensor<8xi32> + %50 = arith.addi %49, %reduced#1 : tensor<8xi32> + %51 = arith.extsi %50 : tensor<8xi32> to tensor<8xi64> + %52 = arith.select %47, %51, %arg16 : tensor<8xi1>, tensor<8xi64> + scf.yield %48, %52 : tensor<8xf32>, tensor<8xi64> + } + %9 = arith.index_cast %4 : i32 to index + %10 = arith.index_cast %arg7 : i32 to index + %11 = arith.muli %9, %10 : index + %12 = arith.index_cast %arg12 : i32 to index + %13 = arith.addi %11, %12 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%13], sizes: [8], strides: [%10] : memref to memref<8xf32, strided<[?], offset: ?>> + %reinterpret_cast_0 = memref.reinterpret_cast %arg4 to offset: [%13], sizes: [8], strides: [%10] : memref to memref<8xi64, strided<[?], offset: ?>> + %14 = arith.addi %9, %c8 : index + %15 = arith.index_cast %arg5 : i32 to index + %16 = arith.maxsi %9, %15 : index + %17 = arith.minsi %14, %16 : index + %18 = arith.subi %17, %9 : index + %extracted_slice = tensor.extract_slice %8#0[0] [%18] [1] : tensor<8xf32> to tensor + %subview = memref.subview %reinterpret_cast[0] [%18] [1] : memref<8xf32, strided<[?], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + %extracted_slice_1 = tensor.extract_slice %8#1[0] [%18] [1] : tensor<8xi64> to tensor + %subview_2 = memref.subview %reinterpret_cast_0[0] [%18] [1] : memref<8xi64, strided<[?], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice_1 in writable %subview_2 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/min_dim/triton-ascend-perf/min_kernel.ttir b/python/test/ops/min_dim/triton-ascend-perf/min_kernel.ttir new file mode 100644 index 000000000..3ee56c742 --- /dev/null +++ b/python/test/ops/min_dim/triton-ascend-perf/min_kernel.ttir @@ -0,0 +1,145 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0) +#loc1 = loc(unknown) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":67:51) +#loc46 = loc(callsite(#loc1 at #loc19)) +module { + tt.func public @min_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0), %arg3: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0), %arg5: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":38:0)) attributes {noinline = false} { + %cst = arith.constant dense<0x7F800000> : tensor<8x256xf32> loc(#loc1) + %cst_0 = arith.constant dense<0x7F800000> : tensor<8xf32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_1 = arith.constant dense<0> : tensor<8xi64> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 
= tt.get_program_id x : i32 loc(#loc2) + %1 = tt.get_program_id y : i32 loc(#loc3) + %2 = arith.muli %0, %c8_i32 : i32 loc(#loc4) + %3 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc5) + %4 = tt.splat %2 : i32 -> tensor<8xi32> loc(#loc6) + %5 = arith.addi %4, %3 : tensor<8xi32> loc(#loc6) + %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc7) + %7 = tt.expand_dims %5 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc8) + %8 = tt.splat %arg4 : i32 -> tensor<8x1xi32> loc(#loc9) + %9 = arith.muli %7, %8 : tensor<8x1xi32> loc(#loc9) + %10 = tt.splat %arg5 : i32 -> tensor<8x1xi32> loc(#loc10) + %11 = arith.muli %9, %10 : tensor<8x1xi32> loc(#loc10) + %12 = tt.splat %arg5 : i32 -> tensor<1x256xi32> loc(#loc11) + %13 = tt.broadcast %11 : tensor<8x1xi32> -> tensor<8x256xi32> loc(#loc12) + %14 = tt.splat %1 : i32 -> tensor<8x256xi32> loc(#loc13) + %15 = tt.splat %arg3 : i32 -> tensor<8x1xi32> loc(#loc14) + %16 = arith.cmpi slt, %7, %15 : tensor<8x1xi32> loc(#loc14) + %17 = tt.splat %arg4 : i32 -> tensor<1x256xi32> loc(#loc15) + %18 = tt.broadcast %16 : tensor<8x1xi1> -> tensor<8x256xi1> loc(#loc16) + %19 = tt.splat %arg0 : !tt.ptr -> tensor<8x256x!tt.ptr> loc(#loc17) + %20 = tt.expand_dims %6 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc45) + %21 = tt.broadcast %20 : tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc45) + %22:2 = scf.for %arg6 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg7 = %cst_0, %arg8 = %cst_1) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %33 = tt.splat %arg6 : i32 -> tensor<256xi32> loc(#loc21) + %34 = arith.addi %33, %6 : tensor<256xi32> loc(#loc21) + %35 = tt.expand_dims %34 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc22) + %36 = arith.muli %35, %12 : tensor<1x256xi32> loc(#loc11) + %37 = tt.broadcast %36 : tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc12) + %38 = arith.addi %13, %37 : tensor<8x256xi32> loc(#loc12) + %39 = arith.addi %38, %14 : tensor<8x256xi32> loc(#loc13) + %40 = arith.cmpi slt, %35, %17 : tensor<1x256xi32> loc(#loc15) + %41 = tt.broadcast %40 : tensor<1x256xi1> -> tensor<8x256xi1> loc(#loc16) + %42 = arith.andi %18, %41 : tensor<8x256xi1> loc(#loc16) + %43 = tt.addptr %19, %39 : tensor<8x256x!tt.ptr>, tensor<8x256xi32> loc(#loc17) + %44 = tt.load %43, %42, %cst : tensor<8x256x!tt.ptr> loc(#loc23) + %45:2 = "tt.reduce"(%44, %21) <{axis = 1 : i32}> ({ + ^bb0(%arg9: f32 loc(callsite(#loc1 at #loc19)), %arg10: i32 loc(callsite(#loc1 at #loc19)), %arg11: f32 loc(callsite(#loc1 at #loc19)), %arg12: i32 loc(callsite(#loc1 at #loc19))): + %52 = arith.cmpf oeq, %arg9, %arg11 : f32 loc(#loc61) + %53 = arith.cmpi slt, %arg10, %arg12 : i32 loc(#loc62) + %54 = arith.andi %52, %53 : i1 loc(#loc63) + %55 = arith.cmpf olt, %arg9, %arg11 : f32 loc(#loc64) + %56 = arith.ori %55, %54 : i1 loc(#loc65) + %57 = arith.select %56, %arg9, %arg11 : f32 loc(#loc66) + %58 = arith.select %56, %arg10, %arg12 : i32 loc(#loc67) + tt.reduce.return %57, %58 : f32, i32 loc(#loc45) + }) : (tensor<8x256xf32>, tensor<8x256xi32>) -> (tensor<8xf32>, tensor<8xi32>) loc(#loc45) + %46 = arith.cmpf olt, %45#0, %arg7 : tensor<8xf32> loc(#loc32) + %47 = arith.select %46, %45#0, %arg7 : tensor<8xi1>, tensor<8xf32> loc(#loc33) + %48 = tt.splat %arg6 : i32 -> tensor<8xi32> loc(#loc34) + %49 = arith.addi %48, %45#1 : tensor<8xi32> loc(#loc34) + %50 = arith.extsi %49 : tensor<8xi32> to tensor<8xi64> loc(#loc35) + %51 = arith.select %46, %50, %arg8 : tensor<8xi1>, tensor<8xi64> loc(#loc35) + scf.yield 
%47, %51 : tensor<8xf32>, tensor<8xi64> loc(#loc36) + } loc(#loc20) + %23 = tt.splat %arg5 : i32 -> tensor<8xi32> loc(#loc37) + %24 = arith.muli %5, %23 : tensor<8xi32> loc(#loc37) + %25 = tt.splat %1 : i32 -> tensor<8xi32> loc(#loc38) + %26 = arith.addi %24, %25 : tensor<8xi32> loc(#loc38) + %27 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> loc(#loc39) + %28 = tt.addptr %27, %26 : tensor<8x!tt.ptr>, tensor<8xi32> loc(#loc39) + %29 = tt.splat %arg2 : !tt.ptr -> tensor<8x!tt.ptr> loc(#loc40) + %30 = tt.addptr %29, %26 : tensor<8x!tt.ptr>, tensor<8xi32> loc(#loc40) + %31 = tt.splat %arg3 : i32 -> tensor<8xi32> loc(#loc41) + %32 = arith.cmpi slt, %5, %31 : tensor<8xi32> loc(#loc41) + tt.store %28, %22#0, %32 : tensor<8x!tt.ptr> loc(#loc42) + tt.store %30, %22#1, %32 : tensor<8x!tt.ptr> loc(#loc43) + tt.return loc(#loc44) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":49:26) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":50:26) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":51:23) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":51:46) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":51:33) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":60:42) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:26) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:37) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:41) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:65) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:45) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:69) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":62:35) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":62:61) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":62:41) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":63:25) +#loc18 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":233:58) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":59:31) +#loc21 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":60:29) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":61:54) +#loc23 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":64:27) +#loc24 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:24) +#loc25 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":212:59) +#loc26 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:44) +#loc27 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:35) +#loc28 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:18) +#loc29 = 
loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:28) +#loc30 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":205:39) +#loc31 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":206:39) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":70:29) +#loc33 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":71:49) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":72:51) +#loc35 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":72:65) +#loc36 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":72:8) +#loc37 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":74:30) +#loc38 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":74:34) +#loc39 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":75:33) +#loc40 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":76:33) +#loc41 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":77:23) +#loc42 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":78:29) +#loc43 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":79:29) +#loc44 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim_ascend_perf.py":79:4) +#loc45 = loc(callsite(#loc18 at #loc19)) +#loc47 = loc(callsite(#loc24 at #loc25)) +#loc48 = loc(callsite(#loc26 at #loc25)) +#loc49 = loc(callsite(#loc27 at #loc25)) +#loc50 = loc(callsite(#loc28 at #loc25)) +#loc51 = loc(callsite(#loc29 at #loc25)) +#loc52 = loc(callsite(#loc30 at #loc25)) +#loc53 = loc(callsite(#loc31 at #loc25)) +#loc54 = loc(callsite(#loc47 at #loc18)) +#loc55 = loc(callsite(#loc48 at #loc18)) +#loc56 = loc(callsite(#loc49 at #loc18)) +#loc57 = loc(callsite(#loc50 at #loc18)) +#loc58 = loc(callsite(#loc51 at #loc18)) +#loc59 = loc(callsite(#loc52 at #loc18)) +#loc60 = loc(callsite(#loc53 at #loc18)) +#loc61 = loc(callsite(#loc54 at #loc19)) +#loc62 = loc(callsite(#loc55 at #loc19)) +#loc63 = loc(callsite(#loc56 at #loc19)) +#loc64 = loc(callsite(#loc57 at #loc19)) +#loc65 = loc(callsite(#loc58 at #loc19)) +#loc66 = loc(callsite(#loc59 at #loc19)) +#loc67 = loc(callsite(#loc60 at #loc19)) diff --git a/python/test/ops/min_dim/triton-ascend/min_kernel.ttadapter b/python/test/ops/min_dim/triton-ascend/min_kernel.ttadapter new file mode 100644 index 000000000..34ae0f4ed --- /dev/null +++ b/python/test/ops/min_dim/triton-ascend/min_kernel.ttadapter @@ -0,0 +1,92 @@ +#map = affine_map<(d0) -> (d0)> +module { + func.func @min_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c-1_i32 = arith.constant -1 : i32 + %c256 = arith.constant 256 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = 
arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0x7F800000 : f32 + %0 = tensor.empty() : tensor<8xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32> + %2 = tensor.empty() : tensor<8xi64> + %3 = linalg.fill ins(%c0_i64 : i64) outs(%2 : tensor<8xi64>) -> tensor<8xi64> + %4 = arith.muli %arg9, %c8_i32 : i32 + %5 = tensor.empty() : tensor<256xi32> + %6 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%5 : tensor<256xi32>) { + ^bb0(%out: i32): + %14 = linalg.index 0 : index + %15 = arith.index_cast %14 : index to i32 + linalg.yield %15 : i32 + } -> tensor<256xi32> + %7 = tensor.empty() : tensor<8x256xi32> + %broadcasted = linalg.broadcast ins(%6 : tensor<256xi32>) outs(%7 : tensor<8x256xi32>) dimensions = [0] + %8:2 = scf.for %arg12 = %c0_i32 to %arg5 step %c256_i32 iter_args(%arg13 = %1, %arg14 = %3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %14 = arith.index_cast %4 : i32 to index + %15 = arith.index_cast %arg5 : i32 to index + %16 = arith.muli %14, %15 : index + %17 = arith.index_cast %arg12 : i32 to index + %18 = arith.addi %16, %17 : index + %reinterpret_cast_3 = memref.reinterpret_cast %arg2 to offset: [%18], sizes: [8, 256], strides: [%15, 1] : memref to memref<8x256xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x256xf32> + %19 = arith.addi %14, %c8 : index + %20 = arith.maxsi %14, %c1 : index + %21 = arith.minsi %19, %20 : index + %22 = arith.subi %21, %14 : index + %23 = arith.addi %17, %c256 : index + %24 = arith.maxsi %17, %15 : index + %25 = arith.minsi %23, %24 : index + %26 = arith.subi %25, %17 : index + %27 = arith.minsi %22, %c8 : index + %28 = arith.minsi %26, %c256 : index + %29 = arith.cmpi slt, %27, %c8 : index + %30 = arith.cmpi slt, %28, %c256 : index + %31 = arith.ori %29, %30 : i1 + scf.if %31 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x256xf32>) + } + %subview_4 = memref.subview %reinterpret_cast_3[0, 0] [%27, %28] [1, 1] : memref<8x256xf32, strided<[?, 1], offset: ?>> to memref> + %subview_5 = memref.subview %alloc[0, 0] [%27, %28] [1, 1] : memref<8x256xf32> to memref> + memref.copy %subview_4, %subview_5 : memref> to memref> + %32 = bufferization.to_tensor %alloc restrict writable : memref<8x256xf32> + %33 = tensor.empty() : tensor<8xi32> + %34 = linalg.fill ins(%c-1_i32 : i32) outs(%33 : tensor<8xi32>) -> tensor<8xi32> + %reduced:2 = linalg.reduce ins(%32, %broadcasted : tensor<8x256xf32>, tensor<8x256xi32>) outs(%1, %34 : tensor<8xf32>, tensor<8xi32>) dimensions = [1] {reduce_mode = "min_with_index"} + (%in: f32, %in_6: i32, %init: f32, %init_7: i32) { + %41 = arith.cmpf olt, %in, %init : f32 + %42 = arith.cmpf oeq, %in, %init : f32 + %43 = arith.cmpi slt, %in_6, %init_7 : i32 + %44 = arith.andi %42, %43 : i1 + %45 = arith.ori %41, %44 : i1 + %46 = arith.select %45, %in, %init : f32 + %47 = arith.select %45, %in_6, %init_7 : i32 + linalg.yield %46, %47 : f32, i32 + } + %35 = arith.cmpf olt, %reduced#0, %arg13 : tensor<8xf32> + %36 = arith.select %35, %reduced#0, %arg13 : tensor<8xi1>, tensor<8xf32> + %37 = linalg.fill ins(%arg12 : i32) outs(%33 : tensor<8xi32>) -> tensor<8xi32> + %38 = arith.addi %37, %reduced#1 : tensor<8xi32> + %39 = arith.extsi %38 : tensor<8xi32> to tensor<8xi64> + %40 = arith.select %35, %39, %arg14 : tensor<8xi1>, tensor<8xi64> + scf.yield %36, %40 : tensor<8xf32>, tensor<8xi64> + } + %9 = arith.index_cast %4 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%9], sizes: [8], 
strides: [1] : memref to memref<8xf32, strided<[1], offset: ?>> + %reinterpret_cast_0 = memref.reinterpret_cast %arg4 to offset: [%9], sizes: [8], strides: [1] : memref to memref<8xi64, strided<[1], offset: ?>> + %10 = arith.addi %9, %c8 : index + %11 = arith.maxsi %9, %c1 : index + %12 = arith.minsi %10, %11 : index + %13 = arith.subi %12, %9 : index + %extracted_slice = tensor.extract_slice %8#0[0] [%13] [1] : tensor<8xf32> to tensor + %subview = memref.subview %reinterpret_cast[0] [%13] [1] : memref<8xf32, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + %extracted_slice_1 = tensor.extract_slice %8#1[0] [%13] [1] : tensor<8xi64> to tensor + %subview_2 = memref.subview %reinterpret_cast_0[0] [%13] [1] : memref<8xi64, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice_1 in writable %subview_2 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/min_dim/triton-ascend/min_kernel.ttir b/python/test/ops/min_dim/triton-ascend/min_kernel.ttir new file mode 100644 index 000000000..6a9a03f4e --- /dev/null +++ b/python/test/ops/min_dim/triton-ascend/min_kernel.ttir @@ -0,0 +1,128 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":37:0) +#loc1 = loc(unknown) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":63:51) +#loc40 = loc(callsite(#loc1 at #loc15)) +module { + tt.func public @min_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":37:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":37:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":37:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":37:0)) attributes {noinline = false} { + %cst = arith.constant dense<0x7F800000> : tensor<8x256xf32> loc(#loc1) + %cst_0 = arith.constant dense<0x7F800000> : tensor<8xf32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<8xi32> loc(#loc1) + %cst_2 = arith.constant dense<1> : tensor<8x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<0> : tensor<8xi64> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c8_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4) + %3 = tt.splat %1 : i32 -> tensor<8xi32> loc(#loc5) + %4 = arith.addi %3, %2 : tensor<8xi32> loc(#loc5) + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc6) + %6 = tt.expand_dims %4 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc7) + %7 = tt.splat %arg3 : i32 -> tensor<8x1xi32> loc(#loc8) + %8 = arith.muli %6, %7 : tensor<8x1xi32> loc(#loc8) + %9 = tt.broadcast %8 : tensor<8x1xi32> -> tensor<8x256xi32> loc(#loc9) + %10 = arith.cmpi slt, %6, %cst_2 : tensor<8x1xi32> loc(#loc10) + %11 = tt.splat %arg3 : i32 -> tensor<1x256xi32> loc(#loc11) + %12 = tt.broadcast %10 : tensor<8x1xi1> -> tensor<8x256xi1> loc(#loc12) + %13 = tt.splat %arg0 : !tt.ptr -> tensor<8x256x!tt.ptr> loc(#loc13) + %14 = tt.expand_dims %5 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc39) + %15 = tt.broadcast %14 : 
tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc39) + %16:2 = scf.for %arg4 = %c0_i32 to %arg3 step %c256_i32 iter_args(%arg5 = %cst_0, %arg6 = %cst_3) -> (tensor<8xf32>, tensor<8xi64>) : i32 { + %22 = tt.splat %arg4 : i32 -> tensor<256xi32> loc(#loc17) + %23 = arith.addi %22, %5 : tensor<256xi32> loc(#loc17) + %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc18) + %25 = tt.broadcast %24 : tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc9) + %26 = arith.addi %9, %25 : tensor<8x256xi32> loc(#loc9) + %27 = arith.cmpi slt, %24, %11 : tensor<1x256xi32> loc(#loc11) + %28 = tt.broadcast %27 : tensor<1x256xi1> -> tensor<8x256xi1> loc(#loc12) + %29 = arith.andi %12, %28 : tensor<8x256xi1> loc(#loc12) + %30 = tt.addptr %13, %26 : tensor<8x256x!tt.ptr>, tensor<8x256xi32> loc(#loc13) + %31 = tt.load %30, %29, %cst : tensor<8x256x!tt.ptr> loc(#loc19) + %32:2 = "tt.reduce"(%31, %15) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32 loc(callsite(#loc1 at #loc15)), %arg8: i32 loc(callsite(#loc1 at #loc15)), %arg9: f32 loc(callsite(#loc1 at #loc15)), %arg10: i32 loc(callsite(#loc1 at #loc15))): + %39 = arith.cmpf oeq, %arg7, %arg9 : f32 loc(#loc55) + %40 = arith.cmpi slt, %arg8, %arg10 : i32 loc(#loc56) + %41 = arith.andi %39, %40 : i1 loc(#loc57) + %42 = arith.cmpf olt, %arg7, %arg9 : f32 loc(#loc58) + %43 = arith.ori %42, %41 : i1 loc(#loc59) + %44 = arith.select %43, %arg7, %arg9 : f32 loc(#loc60) + %45 = arith.select %43, %arg8, %arg10 : i32 loc(#loc61) + tt.reduce.return %44, %45 : f32, i32 loc(#loc39) + }) : (tensor<8x256xf32>, tensor<8x256xi32>) -> (tensor<8xf32>, tensor<8xi32>) loc(#loc39) + %33 = arith.cmpf olt, %32#0, %arg5 : tensor<8xf32> loc(#loc28) + %34 = arith.select %33, %32#0, %arg5 : tensor<8xi1>, tensor<8xf32> loc(#loc29) + %35 = tt.splat %arg4 : i32 -> tensor<8xi32> loc(#loc30) + %36 = arith.addi %35, %32#1 : tensor<8xi32> loc(#loc30) + %37 = arith.extsi %36 : tensor<8xi32> to tensor<8xi64> loc(#loc31) + %38 = arith.select %33, %37, %arg6 : tensor<8xi1>, tensor<8xi64> loc(#loc31) + scf.yield %34, %38 : tensor<8xf32>, tensor<8xi64> loc(#loc32) + } loc(#loc16) + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> loc(#loc33) + %18 = tt.addptr %17, %4 : tensor<8x!tt.ptr>, tensor<8xi32> loc(#loc33) + %19 = tt.splat %arg2 : !tt.ptr -> tensor<8x!tt.ptr> loc(#loc34) + %20 = tt.addptr %19, %4 : tensor<8x!tt.ptr>, tensor<8xi32> loc(#loc34) + %21 = arith.cmpi slt, %4, %cst_1 : tensor<8xi32> loc(#loc35) + tt.store %18, %16#0, %21 : tensor<8x!tt.ptr> loc(#loc36) + tt.store %20, %16#1, %21 : tensor<8x!tt.ptr> loc(#loc37) + tt.return loc(#loc38) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":47:26) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":48:23) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":48:46) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":48:33) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":58:42) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":59:26) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":59:37) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":59:41) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":60:35) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":60:61) +#loc12 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":60:41) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":61:25) +#loc14 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":233:58) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":57:31) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":58:29) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":59:50) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":62:27) +#loc20 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:24) +#loc21 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":212:59) +#loc22 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:44) +#loc23 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":201:35) +#loc24 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:18) +#loc25 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":204:28) +#loc26 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":205:39) +#loc27 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":206:39) +#loc28 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":64:29) +#loc29 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":65:49) +#loc30 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":66:51) +#loc31 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":66:65) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":66:8) +#loc33 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":70:33) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":71:33) +#loc35 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":72:23) +#loc36 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":73:29) +#loc37 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":74:29) +#loc38 = loc("/home/zhengyang/git/flagtree/python/test/ops/min_dim/min_dim.py":74:4) +#loc39 = loc(callsite(#loc14 at #loc15)) +#loc41 = loc(callsite(#loc20 at #loc21)) +#loc42 = loc(callsite(#loc22 at #loc21)) +#loc43 = loc(callsite(#loc23 at #loc21)) +#loc44 = loc(callsite(#loc24 at #loc21)) +#loc45 = loc(callsite(#loc25 at #loc21)) +#loc46 = loc(callsite(#loc26 at #loc21)) +#loc47 = loc(callsite(#loc27 at #loc21)) +#loc48 = loc(callsite(#loc41 at #loc14)) +#loc49 = loc(callsite(#loc42 at #loc14)) +#loc50 = loc(callsite(#loc43 at #loc14)) +#loc51 = loc(callsite(#loc44 at #loc14)) +#loc52 = loc(callsite(#loc45 at #loc14)) +#loc53 = loc(callsite(#loc46 at #loc14)) +#loc54 = loc(callsite(#loc47 at #loc14)) +#loc55 = loc(callsite(#loc48 at #loc15)) +#loc56 = loc(callsite(#loc49 at #loc15)) +#loc57 = loc(callsite(#loc50 at #loc15)) +#loc58 = loc(callsite(#loc51 at #loc15)) +#loc59 = loc(callsite(#loc52 at #loc15)) +#loc60 = loc(callsite(#loc53 at #loc15)) +#loc61 = loc(callsite(#loc54 at #loc15)) diff --git a/python/test/ops/sort/fusion_result.json b/python/test/ops/sort/fusion_result.json new file mode 100644 index 000000000..ec747fa47 --- /dev/null +++ b/python/test/ops/sort/fusion_result.json @@ 
-0,0 +1 @@ +null \ No newline at end of file diff --git a/python/test/ops/sort/sort.py b/python/test/ops/sort/sort.py new file mode 100644 index 000000000..a7f1a4461 --- /dev/null +++ b/python/test/ops/sort/sort.py @@ -0,0 +1,410 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +def unwrap_if_constexpr(o): + return o.value if isinstance(o, tl.constexpr) else o + + +@tl.constexpr +def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype: + num_bits = unwrap_if_constexpr(num_bits) + signed = unwrap_if_constexpr(signed) + return tl.core.get_int_dtype(num_bits, signed) + + +@tl.constexpr +def one_zeros(num_bits: tl.constexpr) -> int: + num_bits = unwrap_if_constexpr(num_bits) + return 1 << (num_bits - 1) + + +@tl.constexpr +def zero_ones(num_bits: tl.constexpr) -> int: + num_bits = unwrap_if_constexpr(num_bits) + return (1 << (num_bits - 1)) - 1 + + +@triton.jit +def uint_to_uint(x, descending: tl.constexpr = False): + out = ~x if descending else x + return out + + +@triton.jit +def uint_to_uint(x, descending: tl.constexpr = False): + out = ~x if descending else x + return out + + +@triton.jit +def int_to_uint(x, descending: tl.constexpr = False): + num_bits: tl.constexpr = x.dtype.primitive_bitwidth + udtype = get_int_t(num_bits, False) + ux = tl.cast(x, udtype, bitcast=True) + if descending: + # 0111111....1 + bit_mask: tl.constexpr = zero_ones(num_bits) + bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype) + out = ux ^ bit_mask_tensor + else: + # 1000000...0 + sign_bit_mask: tl.constexpr = one_zeros(num_bits) + sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype) + out = ux ^ sign_bit_mask_tensor + return out + + +@triton.jit +def floating_to_uint(x, descending: tl.constexpr = False): + num_bits: tl.constexpr = x.dtype.primitive_bitwidth + sdtype = get_int_t(num_bits, True) + udtype = get_int_t(num_bits, False) + sx = x.to(sdtype, bitcast=True) + ux = x.to(udtype, bitcast=True) + + sign_bit_mask_v: tl.constexpr = one_zeros(num_bits) + sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype) + # mind the dtype, right_shift for signed is arithmetic right shift + # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32 + rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype) + mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True) + tl.static_assert(mask.dtype == udtype, "type mismatch") + # 1000000000...0 for positive + # 1111111111...1 for negative + if descending: + out = ux ^ (~mask) + else: + out = ux ^ mask + return out.to(udtype, bitcast=True) + + +@triton.jit +def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False): + if x.dtype.is_floating(): + out = floating_to_uint(x, descending) + elif x.dtype.is_int_signed(): + out = int_to_uint(x, descending) + elif x.dtype.is_int_unsigned(): + out = uint_to_uint(x, descending) + return out + + +@triton.jit +def compute_global_hist_kernel( + arr_ptr, + out_ptr, + num_passes, + m, + n, + tiles_n_per_cta, + TILE_N: tl.constexpr, + TILE_R: tl.constexpr, + num_bits_per_pass: tl.constexpr, + descending: tl.constexpr, +): + # arr_ptr: (m, n) + # out_ptr: (m, n_passes, r), where r 
= 2 ** k_bits is the number of bins + pid = tl.program_id(0) + pid_n = pid // m + pid_m = pid % m + + r: tl.constexpr = 2**num_bits_per_pass + bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1 # a.k.a. 2 ** k_bits - 1 + CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta + cta_n_start = CTA_TILE_N * pid_n + cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n) + + for p in range(0, num_passes): # parallel + bit_offset = p * num_bits_per_pass + for r_start in range(0, r, TILE_R): # parallel + bin_indices = r_start + tl.arange(0, TILE_R) + acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64) + for n_start in range(cta_n_start, cta_n_end, TILE_N): # sequantial + n_offsets = n_start + tl.arange(0, TILE_N) # (TILE_N, ) + mask = n_offsets < cta_n_end + arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask) + arr = convert_to_uint_preverse_order(arr, descending) + key = (arr >> bit_offset) & bfe_mask # (TILE_N, ) + matches = tl.where( + mask, (bin_indices[:, None] == key), False + ) # (TILE_R, TILE_N) + acc += matches + local_sum = tl.sum(acc, axis=1) + tl.atomic_add( + out_ptr + pid_m * num_passes * r + p * r + bin_indices, + local_sum, + sem="relaxed", + ) + + +@triton.jit +def sweep( + arr_ptr, + associate_arr_ptr, # inputs: (key & value) + out_ptr, + associate_out_ptr, # outputs: (key & value) + excumsum_bins_ptr, + status_ptr, # aux input and status + n_passes, + pass_id, + bit_offset, + m, + N, + OUT_N, + TILE_N: tl.constexpr, + TILE_R: tl.constexpr, + k_bits: tl.constexpr, + descending: tl.constexpr, +): + # r: num_bins = 2 ** k_bits + # OUT_N: grid_n = cdiv(N, ) + + # arr_ptr: (m, N) + # out_ptr: (m, N) + # excumsum_bins_ptr: (m, n_passes, r) + # flag_ptr: (m, r, OUT_N) + + # grid: (m, grid_r, grid_n) + + # load data + pid = tl.program_id(0) + pid_m = pid % m + pid_n = pid // m + pid_r = tl.program_id(1) + + # bit masks + aggregate_mask: tl.constexpr = 1 << 30 + inclusive_prefix_mask: tl.constexpr = 1 << 31 + v_mask: tl.constexpr = (1 << 30) - 1 + bfe_mask: tl.constexpr = (1 << k_bits) - 1 # a.k.a. 
2 ** k_bits - 1 + + # initialize flag to zero-local sum is not ready + r: tl.constexpr = 2**k_bits + cta_r_start = pid_r * TILE_R + cta_r_end = tl.minimum(cta_r_start + TILE_R, r) + + # cumsum for a bin_index + n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N) # (TILE_N, ) + mask = n_offsets < N + arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask) + arr_u = convert_to_uint_preverse_order(arr, descending) + key = (arr_u >> bit_offset) & bfe_mask # (TILE_N, ) + + # since triton can only use scalar as condition, loop by bin_index + # status must be pre zero-initialized, or else we have to initialize it + for bin_index in range(cta_r_start, cta_r_end): + matches = tl.where(mask, key == bin_index, False) # (TILE_N, ) bool + # cta level cumsum per bin + # CAUTION: tl.sum in triton 3.2 does not promote type + local_sum = tl.sum(matches.to(tl.uint32), axis=0) + pack0 = aggregate_mask | local_sum + status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n + tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg") + + # decoupled lookback + exclusive_prefix = tl.zeros((), dtype=tl.uint32) + i_lookback = pid_n - 1 + while i_lookback >= 0: + flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback + pack1 = tl.load(status_ptr + flag_offset_i, volatile=True) # uin32 + while pack1 == 0: + pack1 = tl.load(status_ptr + flag_offset_i, volatile=True) + exclusive_prefix += pack1 & v_mask + if (pack1 & aggregate_mask) == aggregate_mask: + i_lookback -= 1 + else: + i_lookback = -1 + pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum) + tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg") + + local_ex_cumsum = ( + tl.cumsum(matches.to(tl.uint32), axis=0) - matches + ) # (TILE_N, ) + ex_cumsum_in_bin = ( + exclusive_prefix + local_ex_cumsum + ) # global ex_cumsum_in_bin (TILE_N, ) + + # ex_cumsum_bins (m, n_passes, r) + ex_cumsum_bins = tl.load( + excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index + ) # scalar + pos = ex_cumsum_bins + ex_cumsum_in_bin # (TILE_N, ) + + # scatter + tl.store(out_ptr + pid_m * N + pos, arr, mask=matches) + if associate_arr_ptr is not None: + associate_arr = tl.load( + associate_arr_ptr + pid_m * N + n_offsets, mask=mask + ) + tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches) + + +def radix_sort(arr, k_bits=8, descending=False): + n = arr.shape[-1] + m = arr.numel() // n + assert n < (1 << 30), "we have not implemented 2**30 per launch" + dtype = arr.dtype + num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8) + + TILE_N = 1024 + tiles_n_per_cta = 8 + CTA_TILE_N = tiles_n_per_cta * TILE_N + + num_bins = 2**k_bits + n_passes = triton.cdiv(num_bits, k_bits) + TILE_R = 16 + + grid_n = triton.cdiv(n, CTA_TILE_N) + grid_for_global_hist = (m * grid_n, 1, 1) + + with torch_device_fn.device(arr.device): + global_hist = torch.zeros( + (m, n_passes, num_bins), device=arr.device, dtype=torch.int32 + ) + compute_global_hist_kernel[grid_for_global_hist]( + arr, + global_hist, + n_passes, + m, + n, + tiles_n_per_cta, + TILE_N, + TILE_R, + k_bits, + descending, + ) + ex_cumsum_bins = torch.cumsum(global_hist, -1) - global_hist + ex_cumsum_bins = ex_cumsum_bins.to(torch.uint32) + + # sort + arr_in = torch.clone(arr) + indices_in = ( + torch.arange(0, n, dtype=torch.int64, device=arr_in.device) + .broadcast_to(arr.shape) + .contiguous() + ) + arr_out = torch.empty_like(arr) + indices_out = torch.empty_like(indices_in) + + TILE_R = 8 + grid_r = triton.cdiv(num_bins, TILE_R) + TILE_N = 2048 + 
grid_n = triton.cdiv(n, TILE_N) + grid_for_sweep = (m * grid_n, grid_r) + + status = torch.empty( + (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32 + ) + + for i in range(0, n_passes): + bit_offset = i * k_bits + status.zero_() + sweep[grid_for_sweep]( + arr_in, + indices_in, + arr_out, + indices_out, + ex_cumsum_bins, + status, + n_passes, + i, + bit_offset, + m, + n, + grid_n, + TILE_N, + TILE_R, + k_bits, + descending, + ) + # print(f"< sorted last {bit_offset + k_bits:>2d} bits: {arr_out}") + arr_in, arr_out = arr_out, arr_in + indices_in, indices_out = indices_out, indices_in + + return arr_in, indices_in + + +def sort_stable(inp, *, stable, dim=-1, descending=False): + # We only implement stable radix sort here + _ = stable + sort_elem_cnt = inp.shape[dim] + if sort_elem_cnt == 1: + return inp, torch.zeros_like(inp, dtype=torch.int64) + + if dim < 0: + dim = dim + inp.ndim + if dim != inp.ndim - 1: + inp = torch.movedim(inp, dim, -1).contiguous() + else: + inp = inp.contiguous() + + dtype = inp.dtype + num_bits_per_pass = 1 if dtype == torch.bool else 4 + out, out_index = radix_sort(inp, num_bits_per_pass, descending) + + if dim != inp.ndim - 1: + out = torch.movedim(out, -1, dim) + out_index = torch.movedim(out_index, -1, dim) + return out, out_index + + +def sort(inp, dim=-1, descending=False): + # We only implement stable radix sort here + return sort_stable(inp, stable=False, dim=dim, descending=descending) + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (2, 32) + dtype = torch.float32 + dim = -1 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_value, ref_index = torch.sort(ref_inp, dim=dim, stable=True, descending=False) + res_value, res_index = sort(inp, dim=dim, descending=False) + check("value", ref_value, res_value) + check("index", ref_index, res_index) diff --git a/python/test/ops/sort/sort_ascend.py b/python/test/ops/sort/sort_ascend.py new file mode 100644 index 000000000..f2cb51b95 --- /dev/null +++ b/python/test/ops/sort/sort_ascend.py @@ -0,0 +1,282 @@ +import math + +import torch +import triton +import triton.language as tl +import triton.language.core as core +from triton.language.standard import _log2, zeros_like + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min) +_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max) +_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min) +_MAX_FLOAT16_VAL =
tl.constexpr(torch.finfo(torch.float16).max) +_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min) +_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max) +_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min) +_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max) +_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min) +_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max) +_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min) +_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max) +_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min) +_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max) + + +@triton.jit +def _get_finfo_val( + dtype, + return_max, +): + if dtype is tl.float32: + if return_max: + return _MAX_FLOAT32_VAL + else: + return _MIN_FLOAT32_VAL + elif dtype is tl.float16: + if return_max: + return _MAX_FLOAT16_VAL + else: + return _MIN_FLOAT16_VAL + elif dtype is tl.bfloat16: + if return_max: + return _MAX_BFLOAT16_VAL + else: + return _MIN_BFLOAT16_VAL + + +@triton.jit +def _get_iinfo_val( + dtype, + return_max, +): + if return_max: + return get_dtype_max(dtype) + else: + return get_dtype_min(dtype) + + +@triton.jit +def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] + + # tl.device_print("shape is: ", shape) + y = core.reshape(x, shape) + y_idx = core.reshape(ids, shape) + + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype) + right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + + left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to( + ids.dtype + ) + right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to( + ids.dtype + ) + left_idx = core.reshape(left_idx, ids.shape) + right_idx = core.reshape(right_idx, ids.shape) + + # actual compare-and-swap + if core.constexpr(x.dtype.primitive_bitwidth) == 8: + idtype = core.int8 + elif core.constexpr(x.dtype.primitive_bitwidth) == 16: + idtype = core.int16 + elif core.constexpr(x.dtype.primitive_bitwidth) == 32: + idtype = core.int32 + elif core.constexpr(x.dtype.primitive_bitwidth) == 64: + idtype = core.int64 + else: + raise ValueError("Unsupported dtype") + + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + cond = (left > right) ^ flip + ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix)) + + if core.constexpr(ids.dtype.primitive_bitwidth) == 8: + idx_dtype = core.int8 + elif core.constexpr(ids.dtype.primitive_bitwidth) == 16: + idx_dtype = core.int16 + elif core.constexpr(ids.dtype.primitive_bitwidth) == 32: + idx_dtype = core.int32 + elif core.constexpr(ids.dtype.primitive_bitwidth) == 64: + idx_dtype = core.int64 + else: + raise ValueError("Unsupported dtype") + + ileft_idx = left_idx.to(idx_dtype, bitcast=True) + iright_idx = right_idx.to(idx_dtype, bitcast=True) + ix_idx = ids.to(idx_dtype, bitcast=True) + ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx)) + + return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True) + + +@triton.jit +def _bitonic_merge( + x, ids, stage: core.constexpr, order: core.constexpr, n_dims: 
core.constexpr +): + """ + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + """ + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape( + core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape + ) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) + return x, ids + + +@triton.jit +def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = dim + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) + return x, ids + + +@triton.jit() +def sort_kernel( + in_ptr, + out_ptr, + out_index_ptr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + DESCENDING: tl.constexpr, + IS_FLOAT: tl.constexpr, +): + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + offset = tl.program_id(0) * N + cols + in_ptr += offset + out_ptr += offset + out_index_ptr += offset + + if IS_FLOAT: + mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING) + in_val = tl.load(in_ptr, mask=mask, other=mask_val) + in_val = tl.where(in_val.dtype.is_fp64(), in_val, in_val.to(tl.float32)) + else: + mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING) + in_val = tl.load(in_ptr, mask=mask, other=mask_val).to(tl.int32) + index_val = tl.arange(0, BLOCK_SIZE) + + sorted_in_val, sorted_index_val = argsort( + in_val, index_val, 0, descending=DESCENDING + ) + tl.store(out_ptr, sorted_in_val, mask=mask) + tl.store(out_index_ptr, sorted_index_val, mask=mask) + + +def sort(inp, dim=-1, descending=False): + sort_elem_cnt = inp.shape[dim] + if sort_elem_cnt == 1: + return inp, torch.zeros_like(inp, dtype=torch.int64) + elif sort_elem_cnt > 128: # TODO: Optimize implementation for large cases. 
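+        # NOTE (editor sketch, not part of the original patch): sort_kernel sorts one
+        # power-of-two block per program (BLOCK_SIZE = next_power_of_2(sort_elem_cnt)),
+        # so rows longer than 128 elements are handed back to torch.sort here;
+        # presumably a multi-block merge would be needed to keep large sorts on device.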
+ return torch.sort(inp, stable=False, dim=dim, descending=descending) + block_size = triton.next_power_of_2(sort_elem_cnt) + + if dim < 0: + dim = dim + inp.ndim + if dim != inp.ndim - 1: + inp = torch.movedim(inp, dim, -1).contiguous() + else: + inp = inp.contiguous() + batch_size = math.prod(inp.shape) // sort_elem_cnt + + out = torch.empty_like(inp) + out_index = torch.empty_like(inp, dtype=torch.int64) + + with torch_device_fn.device(inp.device): + sort_kernel[batch_size,]( + inp, + out, + out_index, + N=sort_elem_cnt, + BLOCK_SIZE=block_size, + DESCENDING=descending, + IS_FLOAT=inp.is_floating_point(), + ) + + if dim != inp.ndim - 1: + out = torch.movedim(out, -1, dim) + out_index = torch.movedim(out_index, -1, dim) + return out, out_index + + +def check(name, ref, res, equal_nan=False, reduce_dim=1, atol=1e-4): + RESOLUTION = { + torch.bool: 0, + torch.uint8: 0, + torch.int8: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float8_e4m3fn: 1e-3, + torch.float8_e5m2: 1e-3, + torch.float8_e4m3fnuz: 1e-3, + torch.float8_e5m2fnuz: 1e-3, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, + } + res = res.cpu() + print( + f"The maximum difference out {name} between torch and triton is " + f"{torch.max(torch.abs(ref - res))}" + ) + rtol = RESOLUTION[ref.dtype] + assert torch.allclose(res, ref, atol=atol * reduce_dim, rtol=rtol), (res, ref) + + +if __name__ == "__main__": + # param + shape = (2, 32) + dtype = torch.float32 + dim = -1 + + # inp + inp = torch.randn(shape, dtype=dtype, device=device) + ref_inp = inp.cpu() + + # op + ref_value, ref_index = torch.sort(ref_inp, dim=dim, stable=True, descending=False) + res_value, res_index = sort(inp, dim=dim, descending=False) + check("value", ref_value, res_value) + check("index", ref_index, res_index) diff --git a/python/test/ops/sort/triton-ascend-failed/sort_kernel.ttadapter b/python/test/ops/sort/triton-ascend-failed/sort_kernel.ttadapter new file mode 100644 index 000000000..c0f48c2cf --- /dev/null +++ b/python/test/ops/sort/triton-ascend-failed/sort_kernel.ttadapter @@ -0,0 +1,805 @@ +#map = affine_map<(d0) -> (d0)> +module { + func.func @sort_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %cst = arith.constant dense<[1, 2, 16]> : tensor<3xi64> + %cst_0 = arith.constant dense<[2, 2, 8]> : tensor<3xi64> + %cst_1 = arith.constant dense<[4, 2, 4]> : tensor<3xi64> + %cst_2 = arith.constant dense<[8, 2, 2]> : tensor<3xi64> + %cst_3 = arith.constant 0.000000e+00 : f32 + %cst_4 = arith.constant dense<[16, 2, 1]> : tensor<3xi64> + %cst_5 = arith.constant dense<32> : tensor<1xi64> + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tensor.empty() : tensor<32xi32> + %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<32xi32>) -> tensor<32xi32> + %2 = tensor.empty() : tensor<1x2x1xi32> + %3 = linalg.fill ins(%c1_i32 : i32) outs(%2 : tensor<1x2x1xi32>) -> tensor<1x2x1xi32> + %4 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<32xi32>) { + ^bb0(%out: i32): +
%285 = linalg.index 0 : index + %286 = arith.index_cast %285 : index to i32 + linalg.yield %286 : i32 + } -> tensor<32xi32> + %5 = arith.muli %arg8, %c32_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%6], sizes: [32], strides: [1] : memref to memref<32xf32, strided<[1], offset: ?>> + %reinterpret_cast_6 = memref.reinterpret_cast %arg3 to offset: [%6], sizes: [32], strides: [1] : memref to memref<32xf32, strided<[1], offset: ?>> + %reinterpret_cast_7 = memref.reinterpret_cast %arg4 to offset: [%6], sizes: [32], strides: [1] : memref to memref<32xi64, strided<[1], offset: ?>> + %alloc = memref.alloc() : memref<32xf32> + memref.copy %reinterpret_cast, %alloc : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32> + %7 = bufferization.to_tensor %alloc restrict writable : memref<32xf32> + %8 = tensor.empty() : tensor<2xi32> + %9 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%8 : tensor<2xi32>) { + ^bb0(%out: i32): + %285 = linalg.index 0 : index + %286 = arith.index_cast %285 : index to i32 + linalg.yield %286 : i32 + } -> tensor<2xi32> + %expanded = tensor.expand_shape %9 [[0, 1, 2]] output_shape [1, 2, 1] : tensor<2xi32> into tensor<1x2x1xi32> + %10 = tensor.empty() : tensor<8x2x2xi32> + %broadcasted = linalg.broadcast ins(%9 : tensor<2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [0, 2] + %reshape = tensor.reshape %broadcasted(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_8 = tensor.reshape %7(%cst_4) : (tensor<32xf32>, tensor<3xi64>) -> tensor<16x2x1xf32> + %reshape_9 = tensor.reshape %4(%cst_4) : (tensor<32xi32>, tensor<3xi64>) -> tensor<16x2x1xi32> + %11 = arith.subi %3, %expanded : tensor<1x2x1xi32> + %12 = arith.sitofp %11 : tensor<1x2x1xi32> to tensor<1x2x1xf32> + %13 = tensor.empty() : tensor<16x2x1xf32> + %collapsed = tensor.collapse_shape %12 [[0, 1], [2]] : tensor<1x2x1xf32> into tensor<2x1xf32> + %broadcasted_10 = linalg.broadcast ins(%collapsed : tensor<2x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [0] + %14 = arith.mulf %reshape_8, %broadcasted_10 : tensor<16x2x1xf32> + %15 = tensor.empty() : tensor<16x1xf32> + %16 = linalg.fill ins(%cst_3 : f32) outs(%15 : tensor<16x1xf32>) -> tensor<16x1xf32> + %reduced = linalg.reduce ins(%14 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_11 = linalg.broadcast ins(%reduced : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %17 = arith.sitofp %expanded : tensor<1x2x1xi32> to tensor<1x2x1xf32> + %collapsed_12 = tensor.collapse_shape %17 [[0, 1], [2]] : tensor<1x2x1xf32> into tensor<2x1xf32> + %broadcasted_13 = linalg.broadcast ins(%collapsed_12 : tensor<2x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [0] + %18 = arith.mulf %reshape_8, %broadcasted_13 : tensor<16x2x1xf32> + %reduced_14 = linalg.reduce ins(%18 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_15 = linalg.broadcast ins(%reduced_14 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %reshape_16 = tensor.reshape %broadcasted_11(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_17 = tensor.reshape %broadcasted_15(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %19 = tensor.empty() : tensor<16x2x1xi32> + 
%collapsed_18 = tensor.collapse_shape %11 [[0, 1], [2]] : tensor<1x2x1xi32> into tensor<2x1xi32> + %broadcasted_19 = linalg.broadcast ins(%collapsed_18 : tensor<2x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [0] + %20 = arith.muli %reshape_9, %broadcasted_19 : tensor<16x2x1xi32> + %21 = tensor.empty() : tensor<16x1xi32> + %22 = linalg.fill ins(%c0_i32 : i32) outs(%21 : tensor<16x1xi32>) -> tensor<16x1xi32> + %reduced_20 = linalg.reduce ins(%20 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_21 = linalg.broadcast ins(%reduced_20 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %expanded_22 = tensor.expand_shape %9 [[0, 1]] output_shape [2, 1] : tensor<2xi32> into tensor<2x1xi32> + %broadcasted_23 = linalg.broadcast ins(%expanded_22 : tensor<2x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [0] + %23 = arith.muli %reshape_9, %broadcasted_23 : tensor<16x2x1xi32> + %reduced_24 = linalg.reduce ins(%23 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_25 = linalg.broadcast ins(%reduced_24 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %reshape_26 = tensor.reshape %broadcasted_21(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_27 = tensor.reshape %broadcasted_25(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %24 = arith.bitcast %reshape_16 : tensor<32xf32> to tensor<32xi32> + %25 = arith.bitcast %reshape_17 : tensor<32xf32> to tensor<32xi32> + %26 = arith.bitcast %7 : tensor<32xf32> to tensor<32xi32> + %27 = arith.cmpf ogt, %reshape_16, %reshape_17 : tensor<32xf32> + %28 = arith.extui %27 : tensor<32xi1> to tensor<32xi32> + %29 = arith.xori %28, %reshape : tensor<32xi32> + %30 = arith.xori %24, %25 : tensor<32xi32> + %31 = arith.cmpi ne, %29, %1 : tensor<32xi32> + %32 = arith.select %31, %30, %1 : tensor<32xi1>, tensor<32xi32> + %33 = arith.xori %26, %32 : tensor<32xi32> + %34 = arith.xori %reshape_26, %reshape_27 : tensor<32xi32> + %35 = arith.select %31, %34, %1 : tensor<32xi1>, tensor<32xi32> + %36 = arith.xori %4, %35 : tensor<32xi32> + %37 = arith.bitcast %33 : tensor<32xi32> to tensor<32xf32> + %38 = tensor.empty() : tensor<4x2x4xi32> + %broadcasted_28 = linalg.broadcast ins(%9 : tensor<2xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [0, 2] + %reshape_29 = tensor.reshape %broadcasted_28(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_30 = tensor.reshape %37(%cst_2) : (tensor<32xf32>, tensor<3xi64>) -> tensor<8x2x2xf32> + %reshape_31 = tensor.reshape %36(%cst_2) : (tensor<32xi32>, tensor<3xi64>) -> tensor<8x2x2xi32> + %39 = tensor.empty() : tensor<8x2x2xf32> + %collapsed_32 = tensor.collapse_shape %12 [[0, 1, 2]] : tensor<1x2x1xf32> into tensor<2xf32> + %broadcasted_33 = linalg.broadcast ins(%collapsed_32 : tensor<2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [0, 2] + %40 = arith.mulf %reshape_30, %broadcasted_33 : tensor<8x2x2xf32> + %41 = tensor.empty() : tensor<8x2xf32> + %42 = linalg.fill ins(%cst_3 : f32) outs(%41 : tensor<8x2xf32>) -> tensor<8x2xf32> + %reduced_34 = linalg.reduce ins(%40 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_35 = linalg.broadcast ins(%reduced_34 : 
tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %collapsed_36 = tensor.collapse_shape %17 [[0, 1, 2]] : tensor<1x2x1xf32> into tensor<2xf32> + %broadcasted_37 = linalg.broadcast ins(%collapsed_36 : tensor<2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [0, 2] + %43 = arith.mulf %reshape_30, %broadcasted_37 : tensor<8x2x2xf32> + %reduced_38 = linalg.reduce ins(%43 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_39 = linalg.broadcast ins(%reduced_38 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %reshape_40 = tensor.reshape %broadcasted_35(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_41 = tensor.reshape %broadcasted_39(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %collapsed_42 = tensor.collapse_shape %11 [[0, 1, 2]] : tensor<1x2x1xi32> into tensor<2xi32> + %broadcasted_43 = linalg.broadcast ins(%collapsed_42 : tensor<2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [0, 2] + %44 = arith.muli %reshape_31, %broadcasted_43 : tensor<8x2x2xi32> + %45 = tensor.empty() : tensor<8x2xi32> + %46 = linalg.fill ins(%c0_i32 : i32) outs(%45 : tensor<8x2xi32>) -> tensor<8x2xi32> + %reduced_44 = linalg.reduce ins(%44 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_45 = linalg.broadcast ins(%reduced_44 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %47 = arith.muli %reshape_31, %broadcasted : tensor<8x2x2xi32> + %reduced_46 = linalg.reduce ins(%47 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_47 = linalg.broadcast ins(%reduced_46 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %reshape_48 = tensor.reshape %broadcasted_45(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_49 = tensor.reshape %broadcasted_47(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %48 = arith.bitcast %reshape_40 : tensor<32xf32> to tensor<32xi32> + %49 = arith.bitcast %reshape_41 : tensor<32xf32> to tensor<32xi32> + %50 = arith.cmpf ogt, %reshape_40, %reshape_41 : tensor<32xf32> + %51 = arith.extui %50 : tensor<32xi1> to tensor<32xi32> + %52 = arith.xori %51, %reshape_29 : tensor<32xi32> + %53 = arith.xori %48, %49 : tensor<32xi32> + %54 = arith.cmpi ne, %52, %1 : tensor<32xi32> + %55 = arith.select %54, %53, %1 : tensor<32xi1>, tensor<32xi32> + %56 = arith.xori %33, %55 : tensor<32xi32> + %57 = arith.xori %reshape_48, %reshape_49 : tensor<32xi32> + %58 = arith.select %54, %57, %1 : tensor<32xi1>, tensor<32xi32> + %59 = arith.xori %36, %58 : tensor<32xi32> + %60 = arith.bitcast %56 : tensor<32xi32> to tensor<32xf32> + %reshape_50 = tensor.reshape %60(%cst_4) : (tensor<32xf32>, tensor<3xi64>) -> tensor<16x2x1xf32> + %reshape_51 = tensor.reshape %59(%cst_4) : (tensor<32xi32>, tensor<3xi64>) -> tensor<16x2x1xi32> + %61 = arith.mulf %reshape_50, %broadcasted_10 : tensor<16x2x1xf32> + %reduced_52 = linalg.reduce ins(%61 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_53 = linalg.broadcast ins(%reduced_52 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = 
[1] + %62 = arith.mulf %reshape_50, %broadcasted_13 : tensor<16x2x1xf32> + %reduced_54 = linalg.reduce ins(%62 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_55 = linalg.broadcast ins(%reduced_54 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %reshape_56 = tensor.reshape %broadcasted_53(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_57 = tensor.reshape %broadcasted_55(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %63 = arith.muli %reshape_51, %broadcasted_19 : tensor<16x2x1xi32> + %reduced_58 = linalg.reduce ins(%63 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_59 = linalg.broadcast ins(%reduced_58 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %64 = arith.muli %reshape_51, %broadcasted_23 : tensor<16x2x1xi32> + %reduced_60 = linalg.reduce ins(%64 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_61 = linalg.broadcast ins(%reduced_60 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %reshape_62 = tensor.reshape %broadcasted_59(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_63 = tensor.reshape %broadcasted_61(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %65 = arith.bitcast %reshape_56 : tensor<32xf32> to tensor<32xi32> + %66 = arith.bitcast %reshape_57 : tensor<32xf32> to tensor<32xi32> + %67 = arith.cmpf ogt, %reshape_56, %reshape_57 : tensor<32xf32> + %68 = arith.extui %67 : tensor<32xi1> to tensor<32xi32> + %69 = arith.xori %68, %reshape_29 : tensor<32xi32> + %70 = arith.xori %65, %66 : tensor<32xi32> + %71 = arith.cmpi ne, %69, %1 : tensor<32xi32> + %72 = arith.select %71, %70, %1 : tensor<32xi1>, tensor<32xi32> + %73 = arith.xori %56, %72 : tensor<32xi32> + %74 = arith.xori %reshape_62, %reshape_63 : tensor<32xi32> + %75 = arith.select %71, %74, %1 : tensor<32xi1>, tensor<32xi32> + %76 = arith.xori %59, %75 : tensor<32xi32> + %77 = arith.bitcast %73 : tensor<32xi32> to tensor<32xf32> + %78 = tensor.empty() : tensor<2x2x8xi32> + %broadcasted_64 = linalg.broadcast ins(%9 : tensor<2xi32>) outs(%78 : tensor<2x2x8xi32>) dimensions = [0, 2] + %reshape_65 = tensor.reshape %broadcasted_64(%cst_5) : (tensor<2x2x8xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_66 = tensor.reshape %77(%cst_1) : (tensor<32xf32>, tensor<3xi64>) -> tensor<4x2x4xf32> + %reshape_67 = tensor.reshape %76(%cst_1) : (tensor<32xi32>, tensor<3xi64>) -> tensor<4x2x4xi32> + %79 = tensor.empty() : tensor<4x2x4xf32> + %broadcasted_68 = linalg.broadcast ins(%collapsed_32 : tensor<2xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [0, 2] + %80 = arith.mulf %reshape_66, %broadcasted_68 : tensor<4x2x4xf32> + %81 = tensor.empty() : tensor<4x4xf32> + %82 = linalg.fill ins(%cst_3 : f32) outs(%81 : tensor<4x4xf32>) -> tensor<4x4xf32> + %reduced_69 = linalg.reduce ins(%80 : tensor<4x2x4xf32>) outs(%82 : tensor<4x4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_70 = linalg.broadcast ins(%reduced_69 : tensor<4x4xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [1] + %broadcasted_71 = linalg.broadcast 
ins(%collapsed_36 : tensor<2xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [0, 2] + %83 = arith.mulf %reshape_66, %broadcasted_71 : tensor<4x2x4xf32> + %reduced_72 = linalg.reduce ins(%83 : tensor<4x2x4xf32>) outs(%82 : tensor<4x4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_73 = linalg.broadcast ins(%reduced_72 : tensor<4x4xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [1] + %reshape_74 = tensor.reshape %broadcasted_70(%cst_5) : (tensor<4x2x4xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_75 = tensor.reshape %broadcasted_73(%cst_5) : (tensor<4x2x4xf32>, tensor<1xi64>) -> tensor<32xf32> + %broadcasted_76 = linalg.broadcast ins(%collapsed_42 : tensor<2xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [0, 2] + %84 = arith.muli %reshape_67, %broadcasted_76 : tensor<4x2x4xi32> + %85 = tensor.empty() : tensor<4x4xi32> + %86 = linalg.fill ins(%c0_i32 : i32) outs(%85 : tensor<4x4xi32>) -> tensor<4x4xi32> + %reduced_77 = linalg.reduce ins(%84 : tensor<4x2x4xi32>) outs(%86 : tensor<4x4xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_78 = linalg.broadcast ins(%reduced_77 : tensor<4x4xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [1] + %87 = arith.muli %reshape_67, %broadcasted_28 : tensor<4x2x4xi32> + %reduced_79 = linalg.reduce ins(%87 : tensor<4x2x4xi32>) outs(%86 : tensor<4x4xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_80 = linalg.broadcast ins(%reduced_79 : tensor<4x4xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [1] + %reshape_81 = tensor.reshape %broadcasted_78(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_82 = tensor.reshape %broadcasted_80(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %88 = arith.bitcast %reshape_74 : tensor<32xf32> to tensor<32xi32> + %89 = arith.bitcast %reshape_75 : tensor<32xf32> to tensor<32xi32> + %90 = arith.cmpf ogt, %reshape_74, %reshape_75 : tensor<32xf32> + %91 = arith.extui %90 : tensor<32xi1> to tensor<32xi32> + %92 = arith.xori %91, %reshape_65 : tensor<32xi32> + %93 = arith.xori %88, %89 : tensor<32xi32> + %94 = arith.cmpi ne, %92, %1 : tensor<32xi32> + %95 = arith.select %94, %93, %1 : tensor<32xi1>, tensor<32xi32> + %96 = arith.xori %73, %95 : tensor<32xi32> + %97 = arith.xori %reshape_81, %reshape_82 : tensor<32xi32> + %98 = arith.select %94, %97, %1 : tensor<32xi1>, tensor<32xi32> + %99 = arith.xori %76, %98 : tensor<32xi32> + %100 = arith.bitcast %96 : tensor<32xi32> to tensor<32xf32> + %reshape_83 = tensor.reshape %100(%cst_2) : (tensor<32xf32>, tensor<3xi64>) -> tensor<8x2x2xf32> + %reshape_84 = tensor.reshape %99(%cst_2) : (tensor<32xi32>, tensor<3xi64>) -> tensor<8x2x2xi32> + %101 = arith.mulf %reshape_83, %broadcasted_33 : tensor<8x2x2xf32> + %reduced_85 = linalg.reduce ins(%101 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_86 = linalg.broadcast ins(%reduced_85 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %102 = arith.mulf %reshape_83, %broadcasted_37 : tensor<8x2x2xf32> + %reduced_87 = linalg.reduce ins(%102 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_88 = 
linalg.broadcast ins(%reduced_87 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %reshape_89 = tensor.reshape %broadcasted_86(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_90 = tensor.reshape %broadcasted_88(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %103 = arith.muli %reshape_84, %broadcasted_43 : tensor<8x2x2xi32> + %reduced_91 = linalg.reduce ins(%103 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_92 = linalg.broadcast ins(%reduced_91 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %104 = arith.muli %reshape_84, %broadcasted : tensor<8x2x2xi32> + %reduced_93 = linalg.reduce ins(%104 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_94 = linalg.broadcast ins(%reduced_93 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %reshape_95 = tensor.reshape %broadcasted_92(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_96 = tensor.reshape %broadcasted_94(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %105 = arith.bitcast %reshape_89 : tensor<32xf32> to tensor<32xi32> + %106 = arith.bitcast %reshape_90 : tensor<32xf32> to tensor<32xi32> + %107 = arith.cmpf ogt, %reshape_89, %reshape_90 : tensor<32xf32> + %108 = arith.extui %107 : tensor<32xi1> to tensor<32xi32> + %109 = arith.xori %108, %reshape_65 : tensor<32xi32> + %110 = arith.xori %105, %106 : tensor<32xi32> + %111 = arith.cmpi ne, %109, %1 : tensor<32xi32> + %112 = arith.select %111, %110, %1 : tensor<32xi1>, tensor<32xi32> + %113 = arith.xori %96, %112 : tensor<32xi32> + %114 = arith.xori %reshape_95, %reshape_96 : tensor<32xi32> + %115 = arith.select %111, %114, %1 : tensor<32xi1>, tensor<32xi32> + %116 = arith.xori %99, %115 : tensor<32xi32> + %117 = arith.bitcast %113 : tensor<32xi32> to tensor<32xf32> + %reshape_97 = tensor.reshape %117(%cst_4) : (tensor<32xf32>, tensor<3xi64>) -> tensor<16x2x1xf32> + %reshape_98 = tensor.reshape %116(%cst_4) : (tensor<32xi32>, tensor<3xi64>) -> tensor<16x2x1xi32> + %118 = arith.mulf %reshape_97, %broadcasted_10 : tensor<16x2x1xf32> + %reduced_99 = linalg.reduce ins(%118 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_100 = linalg.broadcast ins(%reduced_99 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %119 = arith.mulf %reshape_97, %broadcasted_13 : tensor<16x2x1xf32> + %reduced_101 = linalg.reduce ins(%119 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_102 = linalg.broadcast ins(%reduced_101 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %reshape_103 = tensor.reshape %broadcasted_100(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_104 = tensor.reshape %broadcasted_102(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %120 = arith.muli %reshape_98, %broadcasted_19 : tensor<16x2x1xi32> + %reduced_105 = linalg.reduce ins(%120 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + 
linalg.yield %285 : i32 + } + %broadcasted_106 = linalg.broadcast ins(%reduced_105 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %121 = arith.muli %reshape_98, %broadcasted_23 : tensor<16x2x1xi32> + %reduced_107 = linalg.reduce ins(%121 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_108 = linalg.broadcast ins(%reduced_107 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %reshape_109 = tensor.reshape %broadcasted_106(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_110 = tensor.reshape %broadcasted_108(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %122 = arith.bitcast %reshape_103 : tensor<32xf32> to tensor<32xi32> + %123 = arith.bitcast %reshape_104 : tensor<32xf32> to tensor<32xi32> + %124 = arith.cmpf ogt, %reshape_103, %reshape_104 : tensor<32xf32> + %125 = arith.extui %124 : tensor<32xi1> to tensor<32xi32> + %126 = arith.xori %125, %reshape_65 : tensor<32xi32> + %127 = arith.xori %122, %123 : tensor<32xi32> + %128 = arith.cmpi ne, %126, %1 : tensor<32xi32> + %129 = arith.select %128, %127, %1 : tensor<32xi1>, tensor<32xi32> + %130 = arith.xori %113, %129 : tensor<32xi32> + %131 = arith.xori %reshape_109, %reshape_110 : tensor<32xi32> + %132 = arith.select %128, %131, %1 : tensor<32xi1>, tensor<32xi32> + %133 = arith.xori %116, %132 : tensor<32xi32> + %134 = arith.bitcast %130 : tensor<32xi32> to tensor<32xf32> + %135 = tensor.empty() : tensor<1x2x16xi32> + %expanded_111 = tensor.expand_shape %9 [[0, 1]] output_shape [1, 2] : tensor<2xi32> into tensor<1x2xi32> + %broadcasted_112 = linalg.broadcast ins(%expanded_111 : tensor<1x2xi32>) outs(%135 : tensor<1x2x16xi32>) dimensions = [2] + %reshape_113 = tensor.reshape %broadcasted_112(%cst_5) : (tensor<1x2x16xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_114 = tensor.reshape %134(%cst_0) : (tensor<32xf32>, tensor<3xi64>) -> tensor<2x2x8xf32> + %reshape_115 = tensor.reshape %133(%cst_0) : (tensor<32xi32>, tensor<3xi64>) -> tensor<2x2x8xi32> + %136 = tensor.empty() : tensor<2x2x8xf32> + %broadcasted_116 = linalg.broadcast ins(%collapsed_32 : tensor<2xf32>) outs(%136 : tensor<2x2x8xf32>) dimensions = [0, 2] + %137 = arith.mulf %reshape_114, %broadcasted_116 : tensor<2x2x8xf32> + %138 = tensor.empty() : tensor<2x8xf32> + %139 = linalg.fill ins(%cst_3 : f32) outs(%138 : tensor<2x8xf32>) -> tensor<2x8xf32> + %reduced_117 = linalg.reduce ins(%137 : tensor<2x2x8xf32>) outs(%139 : tensor<2x8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_118 = linalg.broadcast ins(%reduced_117 : tensor<2x8xf32>) outs(%136 : tensor<2x2x8xf32>) dimensions = [1] + %broadcasted_119 = linalg.broadcast ins(%collapsed_36 : tensor<2xf32>) outs(%136 : tensor<2x2x8xf32>) dimensions = [0, 2] + %140 = arith.mulf %reshape_114, %broadcasted_119 : tensor<2x2x8xf32> + %reduced_120 = linalg.reduce ins(%140 : tensor<2x2x8xf32>) outs(%139 : tensor<2x8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_121 = linalg.broadcast ins(%reduced_120 : tensor<2x8xf32>) outs(%136 : tensor<2x2x8xf32>) dimensions = [1] + %reshape_122 = tensor.reshape %broadcasted_118(%cst_5) : (tensor<2x2x8xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_123 = tensor.reshape %broadcasted_121(%cst_5) : (tensor<2x2x8xf32>, 
tensor<1xi64>) -> tensor<32xf32> + %broadcasted_124 = linalg.broadcast ins(%collapsed_42 : tensor<2xi32>) outs(%78 : tensor<2x2x8xi32>) dimensions = [0, 2] + %141 = arith.muli %reshape_115, %broadcasted_124 : tensor<2x2x8xi32> + %142 = tensor.empty() : tensor<2x8xi32> + %143 = linalg.fill ins(%c0_i32 : i32) outs(%142 : tensor<2x8xi32>) -> tensor<2x8xi32> + %reduced_125 = linalg.reduce ins(%141 : tensor<2x2x8xi32>) outs(%143 : tensor<2x8xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_126 = linalg.broadcast ins(%reduced_125 : tensor<2x8xi32>) outs(%78 : tensor<2x2x8xi32>) dimensions = [1] + %144 = arith.muli %reshape_115, %broadcasted_64 : tensor<2x2x8xi32> + %reduced_127 = linalg.reduce ins(%144 : tensor<2x2x8xi32>) outs(%143 : tensor<2x8xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_128 = linalg.broadcast ins(%reduced_127 : tensor<2x8xi32>) outs(%78 : tensor<2x2x8xi32>) dimensions = [1] + %reshape_129 = tensor.reshape %broadcasted_126(%cst_5) : (tensor<2x2x8xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_130 = tensor.reshape %broadcasted_128(%cst_5) : (tensor<2x2x8xi32>, tensor<1xi64>) -> tensor<32xi32> + %145 = arith.bitcast %reshape_122 : tensor<32xf32> to tensor<32xi32> + %146 = arith.bitcast %reshape_123 : tensor<32xf32> to tensor<32xi32> + %147 = arith.cmpf ogt, %reshape_122, %reshape_123 : tensor<32xf32> + %148 = arith.extui %147 : tensor<32xi1> to tensor<32xi32> + %149 = arith.xori %148, %reshape_113 : tensor<32xi32> + %150 = arith.xori %145, %146 : tensor<32xi32> + %151 = arith.cmpi ne, %149, %1 : tensor<32xi32> + %152 = arith.select %151, %150, %1 : tensor<32xi1>, tensor<32xi32> + %153 = arith.xori %130, %152 : tensor<32xi32> + %154 = arith.xori %reshape_129, %reshape_130 : tensor<32xi32> + %155 = arith.select %151, %154, %1 : tensor<32xi1>, tensor<32xi32> + %156 = arith.xori %133, %155 : tensor<32xi32> + %157 = arith.bitcast %153 : tensor<32xi32> to tensor<32xf32> + %reshape_131 = tensor.reshape %157(%cst_1) : (tensor<32xf32>, tensor<3xi64>) -> tensor<4x2x4xf32> + %reshape_132 = tensor.reshape %156(%cst_1) : (tensor<32xi32>, tensor<3xi64>) -> tensor<4x2x4xi32> + %158 = arith.mulf %reshape_131, %broadcasted_68 : tensor<4x2x4xf32> + %reduced_133 = linalg.reduce ins(%158 : tensor<4x2x4xf32>) outs(%82 : tensor<4x4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_134 = linalg.broadcast ins(%reduced_133 : tensor<4x4xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [1] + %159 = arith.mulf %reshape_131, %broadcasted_71 : tensor<4x2x4xf32> + %reduced_135 = linalg.reduce ins(%159 : tensor<4x2x4xf32>) outs(%82 : tensor<4x4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_136 = linalg.broadcast ins(%reduced_135 : tensor<4x4xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [1] + %reshape_137 = tensor.reshape %broadcasted_134(%cst_5) : (tensor<4x2x4xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_138 = tensor.reshape %broadcasted_136(%cst_5) : (tensor<4x2x4xf32>, tensor<1xi64>) -> tensor<32xf32> + %160 = arith.muli %reshape_132, %broadcasted_76 : tensor<4x2x4xi32> + %reduced_139 = linalg.reduce ins(%160 : tensor<4x2x4xi32>) outs(%86 : tensor<4x4xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + 
} + %broadcasted_140 = linalg.broadcast ins(%reduced_139 : tensor<4x4xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [1] + %161 = arith.muli %reshape_132, %broadcasted_28 : tensor<4x2x4xi32> + %reduced_141 = linalg.reduce ins(%161 : tensor<4x2x4xi32>) outs(%86 : tensor<4x4xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_142 = linalg.broadcast ins(%reduced_141 : tensor<4x4xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [1] + %reshape_143 = tensor.reshape %broadcasted_140(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_144 = tensor.reshape %broadcasted_142(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %162 = arith.bitcast %reshape_137 : tensor<32xf32> to tensor<32xi32> + %163 = arith.bitcast %reshape_138 : tensor<32xf32> to tensor<32xi32> + %164 = arith.cmpf ogt, %reshape_137, %reshape_138 : tensor<32xf32> + %165 = arith.extui %164 : tensor<32xi1> to tensor<32xi32> + %166 = arith.xori %165, %reshape_113 : tensor<32xi32> + %167 = arith.xori %162, %163 : tensor<32xi32> + %168 = arith.cmpi ne, %166, %1 : tensor<32xi32> + %169 = arith.select %168, %167, %1 : tensor<32xi1>, tensor<32xi32> + %170 = arith.xori %153, %169 : tensor<32xi32> + %171 = arith.xori %reshape_143, %reshape_144 : tensor<32xi32> + %172 = arith.select %168, %171, %1 : tensor<32xi1>, tensor<32xi32> + %173 = arith.xori %156, %172 : tensor<32xi32> + %174 = arith.bitcast %170 : tensor<32xi32> to tensor<32xf32> + %reshape_145 = tensor.reshape %174(%cst_2) : (tensor<32xf32>, tensor<3xi64>) -> tensor<8x2x2xf32> + %reshape_146 = tensor.reshape %173(%cst_2) : (tensor<32xi32>, tensor<3xi64>) -> tensor<8x2x2xi32> + %175 = arith.mulf %reshape_145, %broadcasted_33 : tensor<8x2x2xf32> + %reduced_147 = linalg.reduce ins(%175 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_148 = linalg.broadcast ins(%reduced_147 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %176 = arith.mulf %reshape_145, %broadcasted_37 : tensor<8x2x2xf32> + %reduced_149 = linalg.reduce ins(%176 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_150 = linalg.broadcast ins(%reduced_149 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %reshape_151 = tensor.reshape %broadcasted_148(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_152 = tensor.reshape %broadcasted_150(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %177 = arith.muli %reshape_146, %broadcasted_43 : tensor<8x2x2xi32> + %reduced_153 = linalg.reduce ins(%177 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_154 = linalg.broadcast ins(%reduced_153 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %178 = arith.muli %reshape_146, %broadcasted : tensor<8x2x2xi32> + %reduced_155 = linalg.reduce ins(%178 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_156 = linalg.broadcast ins(%reduced_155 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %reshape_157 = tensor.reshape 
%broadcasted_154(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_158 = tensor.reshape %broadcasted_156(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %179 = arith.bitcast %reshape_151 : tensor<32xf32> to tensor<32xi32> + %180 = arith.bitcast %reshape_152 : tensor<32xf32> to tensor<32xi32> + %181 = arith.cmpf ogt, %reshape_151, %reshape_152 : tensor<32xf32> + %182 = arith.extui %181 : tensor<32xi1> to tensor<32xi32> + %183 = arith.xori %182, %reshape_113 : tensor<32xi32> + %184 = arith.xori %179, %180 : tensor<32xi32> + %185 = arith.cmpi ne, %183, %1 : tensor<32xi32> + %186 = arith.select %185, %184, %1 : tensor<32xi1>, tensor<32xi32> + %187 = arith.xori %170, %186 : tensor<32xi32> + %188 = arith.xori %reshape_157, %reshape_158 : tensor<32xi32> + %189 = arith.select %185, %188, %1 : tensor<32xi1>, tensor<32xi32> + %190 = arith.xori %173, %189 : tensor<32xi32> + %191 = arith.bitcast %187 : tensor<32xi32> to tensor<32xf32> + %reshape_159 = tensor.reshape %191(%cst_4) : (tensor<32xf32>, tensor<3xi64>) -> tensor<16x2x1xf32> + %reshape_160 = tensor.reshape %190(%cst_4) : (tensor<32xi32>, tensor<3xi64>) -> tensor<16x2x1xi32> + %192 = arith.mulf %reshape_159, %broadcasted_10 : tensor<16x2x1xf32> + %reduced_161 = linalg.reduce ins(%192 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_162 = linalg.broadcast ins(%reduced_161 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %193 = arith.mulf %reshape_159, %broadcasted_13 : tensor<16x2x1xf32> + %reduced_163 = linalg.reduce ins(%193 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_164 = linalg.broadcast ins(%reduced_163 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %reshape_165 = tensor.reshape %broadcasted_162(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_166 = tensor.reshape %broadcasted_164(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %194 = arith.muli %reshape_160, %broadcasted_19 : tensor<16x2x1xi32> + %reduced_167 = linalg.reduce ins(%194 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_168 = linalg.broadcast ins(%reduced_167 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %195 = arith.muli %reshape_160, %broadcasted_23 : tensor<16x2x1xi32> + %reduced_169 = linalg.reduce ins(%195 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_170 = linalg.broadcast ins(%reduced_169 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1] + %reshape_171 = tensor.reshape %broadcasted_168(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_172 = tensor.reshape %broadcasted_170(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32> + %196 = arith.bitcast %reshape_165 : tensor<32xf32> to tensor<32xi32> + %197 = arith.bitcast %reshape_166 : tensor<32xf32> to tensor<32xi32> + %198 = arith.cmpf ogt, %reshape_165, %reshape_166 : tensor<32xf32> + %199 = arith.extui %198 : tensor<32xi1> to tensor<32xi32> + %200 = arith.xori %199, %reshape_113 : tensor<32xi32> + %201 = 
arith.xori %196, %197 : tensor<32xi32> + %202 = arith.cmpi ne, %200, %1 : tensor<32xi32> + %203 = arith.select %202, %201, %1 : tensor<32xi1>, tensor<32xi32> + %204 = arith.xori %187, %203 : tensor<32xi32> + %205 = arith.xori %reshape_171, %reshape_172 : tensor<32xi32> + %206 = arith.select %202, %205, %1 : tensor<32xi1>, tensor<32xi32> + %207 = arith.xori %190, %206 : tensor<32xi32> + %208 = arith.bitcast %204 : tensor<32xi32> to tensor<32xf32> + %reshape_173 = tensor.reshape %208(%cst) : (tensor<32xf32>, tensor<3xi64>) -> tensor<1x2x16xf32> + %reshape_174 = tensor.reshape %207(%cst) : (tensor<32xi32>, tensor<3xi64>) -> tensor<1x2x16xi32> + %209 = tensor.empty() : tensor<1x2x16xf32> + %collapsed_175 = tensor.collapse_shape %12 [[0], [1, 2]] : tensor<1x2x1xf32> into tensor<1x2xf32> + %broadcasted_176 = linalg.broadcast ins(%collapsed_175 : tensor<1x2xf32>) outs(%209 : tensor<1x2x16xf32>) dimensions = [2] + %210 = arith.mulf %reshape_173, %broadcasted_176 : tensor<1x2x16xf32> + %211 = tensor.empty() : tensor<1x16xf32> + %212 = linalg.fill ins(%cst_3 : f32) outs(%211 : tensor<1x16xf32>) -> tensor<1x16xf32> + %reduced_177 = linalg.reduce ins(%210 : tensor<1x2x16xf32>) outs(%212 : tensor<1x16xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_178 = linalg.broadcast ins(%reduced_177 : tensor<1x16xf32>) outs(%209 : tensor<1x2x16xf32>) dimensions = [1] + %collapsed_179 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x2x1xf32> into tensor<1x2xf32> + %broadcasted_180 = linalg.broadcast ins(%collapsed_179 : tensor<1x2xf32>) outs(%209 : tensor<1x2x16xf32>) dimensions = [2] + %213 = arith.mulf %reshape_173, %broadcasted_180 : tensor<1x2x16xf32> + %reduced_181 = linalg.reduce ins(%213 : tensor<1x2x16xf32>) outs(%212 : tensor<1x16xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_182 = linalg.broadcast ins(%reduced_181 : tensor<1x16xf32>) outs(%209 : tensor<1x2x16xf32>) dimensions = [1] + %reshape_183 = tensor.reshape %broadcasted_178(%cst_5) : (tensor<1x2x16xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_184 = tensor.reshape %broadcasted_182(%cst_5) : (tensor<1x2x16xf32>, tensor<1xi64>) -> tensor<32xf32> + %collapsed_185 = tensor.collapse_shape %11 [[0], [1, 2]] : tensor<1x2x1xi32> into tensor<1x2xi32> + %broadcasted_186 = linalg.broadcast ins(%collapsed_185 : tensor<1x2xi32>) outs(%135 : tensor<1x2x16xi32>) dimensions = [2] + %214 = arith.muli %reshape_174, %broadcasted_186 : tensor<1x2x16xi32> + %215 = tensor.empty() : tensor<1x16xi32> + %216 = linalg.fill ins(%c0_i32 : i32) outs(%215 : tensor<1x16xi32>) -> tensor<1x16xi32> + %reduced_187 = linalg.reduce ins(%214 : tensor<1x2x16xi32>) outs(%216 : tensor<1x16xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_188 = linalg.broadcast ins(%reduced_187 : tensor<1x16xi32>) outs(%135 : tensor<1x2x16xi32>) dimensions = [1] + %217 = arith.muli %reshape_174, %broadcasted_112 : tensor<1x2x16xi32> + %reduced_189 = linalg.reduce ins(%217 : tensor<1x2x16xi32>) outs(%216 : tensor<1x16xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_190 = linalg.broadcast ins(%reduced_189 : tensor<1x16xi32>) outs(%135 : tensor<1x2x16xi32>) dimensions = [1] + %reshape_191 = tensor.reshape %broadcasted_188(%cst_5) : (tensor<1x2x16xi32>, tensor<1xi64>) -> 
tensor<32xi32> + %reshape_192 = tensor.reshape %broadcasted_190(%cst_5) : (tensor<1x2x16xi32>, tensor<1xi64>) -> tensor<32xi32> + %218 = arith.bitcast %reshape_183 : tensor<32xf32> to tensor<32xi32> + %219 = arith.bitcast %reshape_184 : tensor<32xf32> to tensor<32xi32> + %220 = arith.cmpf ogt, %reshape_183, %reshape_184 : tensor<32xf32> + %221 = arith.xori %218, %219 : tensor<32xi32> + %222 = arith.select %220, %221, %1 : tensor<32xi1>, tensor<32xi32> + %223 = arith.xori %204, %222 : tensor<32xi32> + %224 = arith.xori %reshape_191, %reshape_192 : tensor<32xi32> + %225 = arith.select %220, %224, %1 : tensor<32xi1>, tensor<32xi32> + %226 = arith.xori %207, %225 : tensor<32xi32> + %227 = arith.bitcast %223 : tensor<32xi32> to tensor<32xf32> + %reshape_193 = tensor.reshape %227(%cst_0) : (tensor<32xf32>, tensor<3xi64>) -> tensor<2x2x8xf32> + %reshape_194 = tensor.reshape %226(%cst_0) : (tensor<32xi32>, tensor<3xi64>) -> tensor<2x2x8xi32> + %228 = arith.mulf %reshape_193, %broadcasted_116 : tensor<2x2x8xf32> + %reduced_195 = linalg.reduce ins(%228 : tensor<2x2x8xf32>) outs(%139 : tensor<2x8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_196 = linalg.broadcast ins(%reduced_195 : tensor<2x8xf32>) outs(%136 : tensor<2x2x8xf32>) dimensions = [1] + %229 = arith.mulf %reshape_193, %broadcasted_119 : tensor<2x2x8xf32> + %reduced_197 = linalg.reduce ins(%229 : tensor<2x2x8xf32>) outs(%139 : tensor<2x8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_198 = linalg.broadcast ins(%reduced_197 : tensor<2x8xf32>) outs(%136 : tensor<2x2x8xf32>) dimensions = [1] + %reshape_199 = tensor.reshape %broadcasted_196(%cst_5) : (tensor<2x2x8xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_200 = tensor.reshape %broadcasted_198(%cst_5) : (tensor<2x2x8xf32>, tensor<1xi64>) -> tensor<32xf32> + %230 = arith.muli %reshape_194, %broadcasted_124 : tensor<2x2x8xi32> + %reduced_201 = linalg.reduce ins(%230 : tensor<2x2x8xi32>) outs(%143 : tensor<2x8xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_202 = linalg.broadcast ins(%reduced_201 : tensor<2x8xi32>) outs(%78 : tensor<2x2x8xi32>) dimensions = [1] + %231 = arith.muli %reshape_194, %broadcasted_64 : tensor<2x2x8xi32> + %reduced_203 = linalg.reduce ins(%231 : tensor<2x2x8xi32>) outs(%143 : tensor<2x8xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_204 = linalg.broadcast ins(%reduced_203 : tensor<2x8xi32>) outs(%78 : tensor<2x2x8xi32>) dimensions = [1] + %reshape_205 = tensor.reshape %broadcasted_202(%cst_5) : (tensor<2x2x8xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_206 = tensor.reshape %broadcasted_204(%cst_5) : (tensor<2x2x8xi32>, tensor<1xi64>) -> tensor<32xi32> + %232 = arith.bitcast %reshape_199 : tensor<32xf32> to tensor<32xi32> + %233 = arith.bitcast %reshape_200 : tensor<32xf32> to tensor<32xi32> + %234 = arith.cmpf ogt, %reshape_199, %reshape_200 : tensor<32xf32> + %235 = arith.xori %232, %233 : tensor<32xi32> + %236 = arith.select %234, %235, %1 : tensor<32xi1>, tensor<32xi32> + %237 = arith.xori %223, %236 : tensor<32xi32> + %238 = arith.xori %reshape_205, %reshape_206 : tensor<32xi32> + %239 = arith.select %234, %238, %1 : tensor<32xi1>, tensor<32xi32> + %240 = arith.xori %226, %239 : tensor<32xi32> + %241 = arith.bitcast %237 
: tensor<32xi32> to tensor<32xf32> + %reshape_207 = tensor.reshape %241(%cst_1) : (tensor<32xf32>, tensor<3xi64>) -> tensor<4x2x4xf32> + %reshape_208 = tensor.reshape %240(%cst_1) : (tensor<32xi32>, tensor<3xi64>) -> tensor<4x2x4xi32> + %242 = arith.mulf %reshape_207, %broadcasted_68 : tensor<4x2x4xf32> + %reduced_209 = linalg.reduce ins(%242 : tensor<4x2x4xf32>) outs(%82 : tensor<4x4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_210 = linalg.broadcast ins(%reduced_209 : tensor<4x4xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [1] + %243 = arith.mulf %reshape_207, %broadcasted_71 : tensor<4x2x4xf32> + %reduced_211 = linalg.reduce ins(%243 : tensor<4x2x4xf32>) outs(%82 : tensor<4x4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_212 = linalg.broadcast ins(%reduced_211 : tensor<4x4xf32>) outs(%79 : tensor<4x2x4xf32>) dimensions = [1] + %reshape_213 = tensor.reshape %broadcasted_210(%cst_5) : (tensor<4x2x4xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_214 = tensor.reshape %broadcasted_212(%cst_5) : (tensor<4x2x4xf32>, tensor<1xi64>) -> tensor<32xf32> + %244 = arith.muli %reshape_208, %broadcasted_76 : tensor<4x2x4xi32> + %reduced_215 = linalg.reduce ins(%244 : tensor<4x2x4xi32>) outs(%86 : tensor<4x4xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_216 = linalg.broadcast ins(%reduced_215 : tensor<4x4xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [1] + %245 = arith.muli %reshape_208, %broadcasted_28 : tensor<4x2x4xi32> + %reduced_217 = linalg.reduce ins(%245 : tensor<4x2x4xi32>) outs(%86 : tensor<4x4xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_218 = linalg.broadcast ins(%reduced_217 : tensor<4x4xi32>) outs(%38 : tensor<4x2x4xi32>) dimensions = [1] + %reshape_219 = tensor.reshape %broadcasted_216(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_220 = tensor.reshape %broadcasted_218(%cst_5) : (tensor<4x2x4xi32>, tensor<1xi64>) -> tensor<32xi32> + %246 = arith.bitcast %reshape_213 : tensor<32xf32> to tensor<32xi32> + %247 = arith.bitcast %reshape_214 : tensor<32xf32> to tensor<32xi32> + %248 = arith.cmpf ogt, %reshape_213, %reshape_214 : tensor<32xf32> + %249 = arith.xori %246, %247 : tensor<32xi32> + %250 = arith.select %248, %249, %1 : tensor<32xi1>, tensor<32xi32> + %251 = arith.xori %237, %250 : tensor<32xi32> + %252 = arith.xori %reshape_219, %reshape_220 : tensor<32xi32> + %253 = arith.select %248, %252, %1 : tensor<32xi1>, tensor<32xi32> + %254 = arith.xori %240, %253 : tensor<32xi32> + %255 = arith.bitcast %251 : tensor<32xi32> to tensor<32xf32> + %reshape_221 = tensor.reshape %255(%cst_2) : (tensor<32xf32>, tensor<3xi64>) -> tensor<8x2x2xf32> + %reshape_222 = tensor.reshape %254(%cst_2) : (tensor<32xi32>, tensor<3xi64>) -> tensor<8x2x2xi32> + %256 = arith.mulf %reshape_221, %broadcasted_33 : tensor<8x2x2xf32> + %reduced_223 = linalg.reduce ins(%256 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_224 = linalg.broadcast ins(%reduced_223 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %257 = arith.mulf %reshape_221, %broadcasted_37 : tensor<8x2x2xf32> + %reduced_225 = 
linalg.reduce ins(%257 : tensor<8x2x2xf32>) outs(%42 : tensor<8x2xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_226 = linalg.broadcast ins(%reduced_225 : tensor<8x2xf32>) outs(%39 : tensor<8x2x2xf32>) dimensions = [1] + %reshape_227 = tensor.reshape %broadcasted_224(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_228 = tensor.reshape %broadcasted_226(%cst_5) : (tensor<8x2x2xf32>, tensor<1xi64>) -> tensor<32xf32> + %258 = arith.muli %reshape_222, %broadcasted_43 : tensor<8x2x2xi32> + %reduced_229 = linalg.reduce ins(%258 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_230 = linalg.broadcast ins(%reduced_229 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %259 = arith.muli %reshape_222, %broadcasted : tensor<8x2x2xi32> + %reduced_231 = linalg.reduce ins(%259 : tensor<8x2x2xi32>) outs(%46 : tensor<8x2xi32>) dimensions = [1] + (%in: i32, %init: i32) { + %285 = arith.addi %in, %init : i32 + linalg.yield %285 : i32 + } + %broadcasted_232 = linalg.broadcast ins(%reduced_231 : tensor<8x2xi32>) outs(%10 : tensor<8x2x2xi32>) dimensions = [1] + %reshape_233 = tensor.reshape %broadcasted_230(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %reshape_234 = tensor.reshape %broadcasted_232(%cst_5) : (tensor<8x2x2xi32>, tensor<1xi64>) -> tensor<32xi32> + %260 = arith.bitcast %reshape_227 : tensor<32xf32> to tensor<32xi32> + %261 = arith.bitcast %reshape_228 : tensor<32xf32> to tensor<32xi32> + %262 = arith.cmpf ogt, %reshape_227, %reshape_228 : tensor<32xf32> + %263 = arith.xori %260, %261 : tensor<32xi32> + %264 = arith.select %262, %263, %1 : tensor<32xi1>, tensor<32xi32> + %265 = arith.xori %251, %264 : tensor<32xi32> + %266 = arith.xori %reshape_233, %reshape_234 : tensor<32xi32> + %267 = arith.select %262, %266, %1 : tensor<32xi1>, tensor<32xi32> + %268 = arith.xori %254, %267 : tensor<32xi32> + %269 = arith.bitcast %265 : tensor<32xi32> to tensor<32xf32> + %reshape_235 = tensor.reshape %269(%cst_4) : (tensor<32xf32>, tensor<3xi64>) -> tensor<16x2x1xf32> + %reshape_236 = tensor.reshape %268(%cst_4) : (tensor<32xi32>, tensor<3xi64>) -> tensor<16x2x1xi32> + %270 = arith.mulf %reshape_235, %broadcasted_10 : tensor<16x2x1xf32> + %reduced_237 = linalg.reduce ins(%270 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_238 = linalg.broadcast ins(%reduced_237 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %271 = arith.mulf %reshape_235, %broadcasted_13 : tensor<16x2x1xf32> + %reduced_239 = linalg.reduce ins(%271 : tensor<16x2x1xf32>) outs(%16 : tensor<16x1xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %285 = arith.addf %in, %init : f32 + linalg.yield %285 : f32 + } + %broadcasted_240 = linalg.broadcast ins(%reduced_239 : tensor<16x1xf32>) outs(%13 : tensor<16x2x1xf32>) dimensions = [1] + %reshape_241 = tensor.reshape %broadcasted_238(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %reshape_242 = tensor.reshape %broadcasted_240(%cst_5) : (tensor<16x2x1xf32>, tensor<1xi64>) -> tensor<32xf32> + %272 = arith.muli %reshape_236, %broadcasted_19 : tensor<16x2x1xi32> + %reduced_243 = linalg.reduce ins(%272 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1] + 
(%in: i32, %init: i32) {
+ %285 = arith.addi %in, %init : i32
+ linalg.yield %285 : i32
+ }
+ %broadcasted_244 = linalg.broadcast ins(%reduced_243 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1]
+ %273 = arith.muli %reshape_236, %broadcasted_23 : tensor<16x2x1xi32>
+ %reduced_245 = linalg.reduce ins(%273 : tensor<16x2x1xi32>) outs(%22 : tensor<16x1xi32>) dimensions = [1]
+ (%in: i32, %init: i32) {
+ %285 = arith.addi %in, %init : i32
+ linalg.yield %285 : i32
+ }
+ %broadcasted_246 = linalg.broadcast ins(%reduced_245 : tensor<16x1xi32>) outs(%19 : tensor<16x2x1xi32>) dimensions = [1]
+ %reshape_247 = tensor.reshape %broadcasted_244(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32>
+ %reshape_248 = tensor.reshape %broadcasted_246(%cst_5) : (tensor<16x2x1xi32>, tensor<1xi64>) -> tensor<32xi32>
+ %274 = arith.bitcast %reshape_241 : tensor<32xf32> to tensor<32xi32>
+ %275 = arith.bitcast %reshape_242 : tensor<32xf32> to tensor<32xi32>
+ %276 = arith.cmpf ogt, %reshape_241, %reshape_242 : tensor<32xf32>
+ %277 = arith.xori %274, %275 : tensor<32xi32>
+ %278 = arith.select %276, %277, %1 : tensor<32xi1>, tensor<32xi32>
+ %279 = arith.xori %265, %278 : tensor<32xi32>
+ %280 = arith.xori %reshape_247, %reshape_248 : tensor<32xi32>
+ %281 = arith.select %276, %280, %1 : tensor<32xi1>, tensor<32xi32>
+ %282 = arith.xori %268, %281 : tensor<32xi32>
+ %283 = arith.bitcast %279 : tensor<32xi32> to tensor<32xf32>
+ bufferization.materialize_in_destination %283 in writable %reinterpret_cast_6 : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+ %284 = arith.extsi %282 : tensor<32xi32> to tensor<32xi64>
+ bufferization.materialize_in_destination %284 in writable %reinterpret_cast_7 : (tensor<32xi64>, memref<32xi64, strided<[1], offset: ?>>) -> ()
+ return
+ }
+}
+
diff --git a/python/test/ops/sort/triton-ascend-failed/sort_kernel.ttir b/python/test/ops/sort/triton-ascend-failed/sort_kernel.ttir
new file mode 100644
index 000000000..7a0b7b59e
--- /dev/null
+++ b/python/test/ops/sort/triton-ascend-failed/sort_kernel.ttir
@@ -0,0 +1,1029 @@
+#loc = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":172:0)
+#loc1 = loc(unknown)
+#loc12 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":167:76)
+#loc13 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":198:27)
+#loc18 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":157:71)
+#loc23 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":80:52)
+#loc28 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":81:47)
+#loc34 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":85:60)
+#loc38 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":88:55)
+#loc67 = loc(callsite(#loc1 at #loc23))
+#loc73 = loc(callsite(#loc1 at #loc28))
+#loc80 = loc(callsite(#loc1 at #loc34))
+#loc85 = loc(callsite(#loc1 at #loc38))
+#loc111 = loc(callsite(#loc67 at #loc18))
+#loc117 = loc(callsite(#loc73 at #loc18))
+#loc125 = loc(callsite(#loc80 at #loc18))
+#loc131 = loc(callsite(#loc85 at #loc18))
+#loc154 = loc(callsite(#loc111 at #loc12))
+#loc160 = loc(callsite(#loc117 at #loc12))
+#loc168 = loc(callsite(#loc125 at #loc12))
+#loc174 = loc(callsite(#loc131 at #loc12))
+#loc193 = loc(callsite(#loc154 at #loc13))
+#loc196 = loc(callsite(#loc160 at #loc13))
+#loc199 = loc(callsite(#loc168 at #loc13))
+#loc202 = loc(callsite(#loc174 at
#loc13)) +module { + tt.func public @sort_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":172:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":172:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":172:0)) attributes {noinline = false} { + %cst = arith.constant dense<3.40282347E+38> : tensor<32xf32> loc(#loc1) + %cst_0 = arith.constant dense<0> : tensor<32xi32> loc(#loc1) + %cst_1 = arith.constant dense<1> : tensor<1x2x1xi32> loc(#loc1) + %cst_2 = arith.constant dense<32> : tensor<32xi32> loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> loc(#loc2) + %1 = arith.cmpi slt, %0, %cst_2 : tensor<32xi32> loc(#loc3) + %2 = tt.get_program_id x : i32 loc(#loc4) + %3 = arith.muli %2, %c32_i32 : i32 loc(#loc5) + %4 = tt.splat %3 : i32 -> tensor<32xi32> loc(#loc6) + %5 = arith.addi %4, %0 : tensor<32xi32> loc(#loc6) + %6 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> loc(#loc7) + %7 = tt.addptr %6, %5 : tensor<32x!tt.ptr>, tensor<32xi32> loc(#loc7) + %8 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> loc(#loc8) + %9 = tt.addptr %8, %5 : tensor<32x!tt.ptr>, tensor<32xi32> loc(#loc8) + %10 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %5 : tensor<32x!tt.ptr>, tensor<32xi32> loc(#loc9) + %12 = tt.load %7, %1, %cst : tensor<32x!tt.ptr> loc(#loc10) + %13 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> loc(#loc102) + %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> loc(#loc103) + %15 = tt.expand_dims %14 {axis = 2 : i32} : tensor<1x2xi32> -> tensor<1x2x1xi32> loc(#loc103) + %16 = tt.broadcast %15 : tensor<1x2x1xi32> -> tensor<8x2x2xi32> loc(#loc104) + %17 = tt.reshape %16 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc105) + %18 = tt.reshape %12 : tensor<32xf32> -> tensor<16x2x1xf32> loc(#loc149) + %19 = tt.reshape %0 : tensor<32xi32> -> tensor<16x2x1xi32> loc(#loc150) + %20 = arith.subi %cst_1, %15 : tensor<1x2x1xi32> loc(#loc151) + %21 = arith.sitofp %20 : tensor<1x2x1xi32> to tensor<1x2x1xf32> loc(#loc152) + %22 = tt.broadcast %21 : tensor<1x2x1xf32> -> tensor<16x2x1xf32> loc(#loc152) + %23 = arith.mulf %18, %22 : tensor<16x2x1xf32> loc(#loc152) + %24 = "tt.reduce"(%23) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc192) + %25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc156) + %26 = tt.broadcast %25 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc157) + %27 = arith.sitofp %15 : tensor<1x2x1xi32> to tensor<1x2x1xf32> loc(#loc158) + %28 = tt.broadcast %27 : tensor<1x2x1xf32> -> tensor<16x2x1xf32> loc(#loc158) + %29 = arith.mulf %18, %28 : tensor<16x2x1xf32> loc(#loc158) + %30 = "tt.reduce"(%29) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc195) + %31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> 
loc(#loc162) + %32 = tt.broadcast %31 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc163) + %33 = tt.reshape %26 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc164) + %34 = tt.reshape %32 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc165) + %35 = tt.broadcast %20 : tensor<1x2x1xi32> -> tensor<16x2x1xi32> loc(#loc166) + %36 = arith.muli %19, %35 : tensor<16x2x1xi32> loc(#loc166) + %37 = "tt.reduce"(%36) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc198) + %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc170) + %39 = tt.broadcast %38 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc171) + %40 = tt.broadcast %15 : tensor<1x2x1xi32> -> tensor<16x2x1xi32> loc(#loc172) + %41 = arith.muli %19, %40 : tensor<16x2x1xi32> loc(#loc172) + %42 = "tt.reduce"(%41) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc201) + %43 = tt.expand_dims %42 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc176) + %44 = tt.broadcast %43 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc177) + %45 = tt.reshape %39 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc178) + %46 = tt.reshape %44 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc179) + %47 = tt.bitcast %33 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %48 = tt.bitcast %34 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %49 = tt.bitcast %12 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %50 = arith.cmpf ogt, %33, %34 : tensor<32xf32> loc(#loc183) + %51 = arith.extui %50 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %52 = arith.xori %51, %17 : tensor<32xi32> loc(#loc184) + %53 = arith.xori %47, %48 : tensor<32xi32> loc(#loc185) + %54 = arith.cmpi ne, %52, %cst_0 : tensor<32xi32> loc(#loc186) + %55 = arith.select %54, %53, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %56 = arith.xori %49, %55 : tensor<32xi32> loc(#loc187) + %57 = arith.xori %45, %46 : tensor<32xi32> loc(#loc188) + %58 = arith.select %54, %57, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %59 = arith.xori %0, %58 : tensor<32xi32> loc(#loc190) + %60 = tt.bitcast %56 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %61 = tt.broadcast %15 : tensor<1x2x1xi32> -> tensor<4x2x4xi32> loc(#loc104) + %62 = tt.reshape %61 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc105) + %63 = tt.reshape %60 : tensor<32xf32> -> tensor<8x2x2xf32> loc(#loc149) + %64 = tt.reshape %59 : tensor<32xi32> -> tensor<8x2x2xi32> loc(#loc150) + %65 = tt.broadcast %21 : tensor<1x2x1xf32> -> tensor<8x2x2xf32> loc(#loc152) + %66 = arith.mulf %63, %65 : tensor<8x2x2xf32> loc(#loc152) + %67 = "tt.reduce"(%66) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc192) + %68 = tt.expand_dims %67 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc156) + %69 = tt.broadcast %68 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc157) + %70 = tt.broadcast %27 : tensor<1x2x1xf32> -> 
tensor<8x2x2xf32> loc(#loc158) + %71 = arith.mulf %63, %70 : tensor<8x2x2xf32> loc(#loc158) + %72 = "tt.reduce"(%71) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc195) + %73 = tt.expand_dims %72 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc162) + %74 = tt.broadcast %73 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc163) + %75 = tt.reshape %69 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc164) + %76 = tt.reshape %74 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc165) + %77 = tt.broadcast %20 : tensor<1x2x1xi32> -> tensor<8x2x2xi32> loc(#loc166) + %78 = arith.muli %64, %77 : tensor<8x2x2xi32> loc(#loc166) + %79 = "tt.reduce"(%78) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc198) + %80 = tt.expand_dims %79 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc170) + %81 = tt.broadcast %80 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc171) + %82 = arith.muli %64, %16 : tensor<8x2x2xi32> loc(#loc172) + %83 = "tt.reduce"(%82) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc201) + %84 = tt.expand_dims %83 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc176) + %85 = tt.broadcast %84 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc177) + %86 = tt.reshape %81 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc178) + %87 = tt.reshape %85 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc179) + %88 = tt.bitcast %75 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %89 = tt.bitcast %76 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %90 = tt.bitcast %60 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %91 = arith.cmpf ogt, %75, %76 : tensor<32xf32> loc(#loc183) + %92 = arith.extui %91 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %93 = arith.xori %92, %62 : tensor<32xi32> loc(#loc184) + %94 = arith.xori %88, %89 : tensor<32xi32> loc(#loc185) + %95 = arith.cmpi ne, %93, %cst_0 : tensor<32xi32> loc(#loc186) + %96 = arith.select %95, %94, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %97 = arith.xori %90, %96 : tensor<32xi32> loc(#loc187) + %98 = arith.xori %86, %87 : tensor<32xi32> loc(#loc188) + %99 = arith.select %95, %98, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %100 = arith.xori %59, %99 : tensor<32xi32> loc(#loc190) + %101 = tt.bitcast %97 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %102 = tt.reshape %101 : tensor<32xf32> -> tensor<16x2x1xf32> loc(#loc149) + %103 = tt.reshape %100 : tensor<32xi32> -> tensor<16x2x1xi32> loc(#loc150) + %104 = arith.mulf %102, %22 : tensor<16x2x1xf32> loc(#loc152) + %105 = "tt.reduce"(%104) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc192) + %106 = tt.expand_dims %105 {axis = 1 : i32} : tensor<16x1xf32> -> 
tensor<16x1x1xf32> loc(#loc156) + %107 = tt.broadcast %106 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc157) + %108 = arith.mulf %102, %28 : tensor<16x2x1xf32> loc(#loc158) + %109 = "tt.reduce"(%108) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc195) + %110 = tt.expand_dims %109 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc162) + %111 = tt.broadcast %110 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc163) + %112 = tt.reshape %107 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc164) + %113 = tt.reshape %111 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc165) + %114 = arith.muli %103, %35 : tensor<16x2x1xi32> loc(#loc166) + %115 = "tt.reduce"(%114) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc198) + %116 = tt.expand_dims %115 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc170) + %117 = tt.broadcast %116 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc171) + %118 = arith.muli %103, %40 : tensor<16x2x1xi32> loc(#loc172) + %119 = "tt.reduce"(%118) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc201) + %120 = tt.expand_dims %119 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc176) + %121 = tt.broadcast %120 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc177) + %122 = tt.reshape %117 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc178) + %123 = tt.reshape %121 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc179) + %124 = tt.bitcast %112 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %125 = tt.bitcast %113 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %126 = tt.bitcast %101 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %127 = arith.cmpf ogt, %112, %113 : tensor<32xf32> loc(#loc183) + %128 = arith.extui %127 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %129 = arith.xori %128, %62 : tensor<32xi32> loc(#loc184) + %130 = arith.xori %124, %125 : tensor<32xi32> loc(#loc185) + %131 = arith.cmpi ne, %129, %cst_0 : tensor<32xi32> loc(#loc186) + %132 = arith.select %131, %130, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %133 = arith.xori %126, %132 : tensor<32xi32> loc(#loc187) + %134 = arith.xori %122, %123 : tensor<32xi32> loc(#loc188) + %135 = arith.select %131, %134, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %136 = arith.xori %100, %135 : tensor<32xi32> loc(#loc190) + %137 = tt.bitcast %133 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %138 = tt.broadcast %15 : tensor<1x2x1xi32> -> tensor<2x2x8xi32> loc(#loc104) + %139 = tt.reshape %138 : tensor<2x2x8xi32> -> tensor<32xi32> loc(#loc105) + %140 = tt.reshape %137 : tensor<32xf32> -> tensor<4x2x4xf32> loc(#loc149) + %141 = tt.reshape %136 : tensor<32xi32> -> tensor<4x2x4xi32> loc(#loc150) + %142 = tt.broadcast %21 : tensor<1x2x1xf32> -> tensor<4x2x4xf32> loc(#loc152) + %143 = arith.mulf %140, %142 : tensor<4x2x4xf32> loc(#loc152) + %144 = "tt.reduce"(%143) <{axis = 1 : i32}> ({ + 
^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<4x2x4xf32>) -> tensor<4x4xf32> loc(#loc192) + %145 = tt.expand_dims %144 {axis = 1 : i32} : tensor<4x4xf32> -> tensor<4x1x4xf32> loc(#loc156) + %146 = tt.broadcast %145 : tensor<4x1x4xf32> -> tensor<4x2x4xf32> loc(#loc157) + %147 = tt.broadcast %27 : tensor<1x2x1xf32> -> tensor<4x2x4xf32> loc(#loc158) + %148 = arith.mulf %140, %147 : tensor<4x2x4xf32> loc(#loc158) + %149 = "tt.reduce"(%148) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<4x2x4xf32>) -> tensor<4x4xf32> loc(#loc195) + %150 = tt.expand_dims %149 {axis = 1 : i32} : tensor<4x4xf32> -> tensor<4x1x4xf32> loc(#loc162) + %151 = tt.broadcast %150 : tensor<4x1x4xf32> -> tensor<4x2x4xf32> loc(#loc163) + %152 = tt.reshape %146 : tensor<4x2x4xf32> -> tensor<32xf32> loc(#loc164) + %153 = tt.reshape %151 : tensor<4x2x4xf32> -> tensor<32xf32> loc(#loc165) + %154 = tt.broadcast %20 : tensor<1x2x1xi32> -> tensor<4x2x4xi32> loc(#loc166) + %155 = arith.muli %141, %154 : tensor<4x2x4xi32> loc(#loc166) + %156 = "tt.reduce"(%155) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<4x2x4xi32>) -> tensor<4x4xi32> loc(#loc198) + %157 = tt.expand_dims %156 {axis = 1 : i32} : tensor<4x4xi32> -> tensor<4x1x4xi32> loc(#loc170) + %158 = tt.broadcast %157 : tensor<4x1x4xi32> -> tensor<4x2x4xi32> loc(#loc171) + %159 = arith.muli %141, %61 : tensor<4x2x4xi32> loc(#loc172) + %160 = "tt.reduce"(%159) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<4x2x4xi32>) -> tensor<4x4xi32> loc(#loc201) + %161 = tt.expand_dims %160 {axis = 1 : i32} : tensor<4x4xi32> -> tensor<4x1x4xi32> loc(#loc176) + %162 = tt.broadcast %161 : tensor<4x1x4xi32> -> tensor<4x2x4xi32> loc(#loc177) + %163 = tt.reshape %158 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc178) + %164 = tt.reshape %162 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc179) + %165 = tt.bitcast %152 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %166 = tt.bitcast %153 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %167 = tt.bitcast %137 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %168 = arith.cmpf ogt, %152, %153 : tensor<32xf32> loc(#loc183) + %169 = arith.extui %168 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %170 = arith.xori %169, %139 : tensor<32xi32> loc(#loc184) + %171 = arith.xori %165, %166 : tensor<32xi32> loc(#loc185) + %172 = arith.cmpi ne, %170, %cst_0 : tensor<32xi32> loc(#loc186) + %173 = arith.select %172, %171, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %174 = arith.xori %167, %173 : tensor<32xi32> loc(#loc187) + %175 = arith.xori %163, %164 : tensor<32xi32> loc(#loc188) + %176 = arith.select %172, %175, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %177 = arith.xori %136, %176 : tensor<32xi32> loc(#loc190) + %178 = tt.bitcast %174 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %179 = tt.reshape %178 : tensor<32xf32> -> 
tensor<8x2x2xf32> loc(#loc149) + %180 = tt.reshape %177 : tensor<32xi32> -> tensor<8x2x2xi32> loc(#loc150) + %181 = arith.mulf %179, %65 : tensor<8x2x2xf32> loc(#loc152) + %182 = "tt.reduce"(%181) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc192) + %183 = tt.expand_dims %182 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc156) + %184 = tt.broadcast %183 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc157) + %185 = arith.mulf %179, %70 : tensor<8x2x2xf32> loc(#loc158) + %186 = "tt.reduce"(%185) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc195) + %187 = tt.expand_dims %186 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc162) + %188 = tt.broadcast %187 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc163) + %189 = tt.reshape %184 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc164) + %190 = tt.reshape %188 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc165) + %191 = arith.muli %180, %77 : tensor<8x2x2xi32> loc(#loc166) + %192 = "tt.reduce"(%191) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc198) + %193 = tt.expand_dims %192 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc170) + %194 = tt.broadcast %193 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc171) + %195 = arith.muli %180, %16 : tensor<8x2x2xi32> loc(#loc172) + %196 = "tt.reduce"(%195) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc201) + %197 = tt.expand_dims %196 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc176) + %198 = tt.broadcast %197 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc177) + %199 = tt.reshape %194 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc178) + %200 = tt.reshape %198 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc179) + %201 = tt.bitcast %189 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %202 = tt.bitcast %190 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %203 = tt.bitcast %178 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %204 = arith.cmpf ogt, %189, %190 : tensor<32xf32> loc(#loc183) + %205 = arith.extui %204 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %206 = arith.xori %205, %139 : tensor<32xi32> loc(#loc184) + %207 = arith.xori %201, %202 : tensor<32xi32> loc(#loc185) + %208 = arith.cmpi ne, %206, %cst_0 : tensor<32xi32> loc(#loc186) + %209 = arith.select %208, %207, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %210 = arith.xori %203, %209 : tensor<32xi32> loc(#loc187) + %211 = arith.xori %199, %200 : tensor<32xi32> loc(#loc188) + %212 = arith.select %208, %211, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %213 = arith.xori %177, %212 : tensor<32xi32> loc(#loc190) + %214 = tt.bitcast %210 : tensor<32xi32> -> tensor<32xf32> 
loc(#loc191) + %215 = tt.reshape %214 : tensor<32xf32> -> tensor<16x2x1xf32> loc(#loc149) + %216 = tt.reshape %213 : tensor<32xi32> -> tensor<16x2x1xi32> loc(#loc150) + %217 = arith.mulf %215, %22 : tensor<16x2x1xf32> loc(#loc152) + %218 = "tt.reduce"(%217) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc192) + %219 = tt.expand_dims %218 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc156) + %220 = tt.broadcast %219 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc157) + %221 = arith.mulf %215, %28 : tensor<16x2x1xf32> loc(#loc158) + %222 = "tt.reduce"(%221) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc195) + %223 = tt.expand_dims %222 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc162) + %224 = tt.broadcast %223 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc163) + %225 = tt.reshape %220 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc164) + %226 = tt.reshape %224 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc165) + %227 = arith.muli %216, %35 : tensor<16x2x1xi32> loc(#loc166) + %228 = "tt.reduce"(%227) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc198) + %229 = tt.expand_dims %228 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc170) + %230 = tt.broadcast %229 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc171) + %231 = arith.muli %216, %40 : tensor<16x2x1xi32> loc(#loc172) + %232 = "tt.reduce"(%231) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc201) + %233 = tt.expand_dims %232 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc176) + %234 = tt.broadcast %233 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc177) + %235 = tt.reshape %230 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc178) + %236 = tt.reshape %234 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc179) + %237 = tt.bitcast %225 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %238 = tt.bitcast %226 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %239 = tt.bitcast %214 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %240 = arith.cmpf ogt, %225, %226 : tensor<32xf32> loc(#loc183) + %241 = arith.extui %240 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %242 = arith.xori %241, %139 : tensor<32xi32> loc(#loc184) + %243 = arith.xori %237, %238 : tensor<32xi32> loc(#loc185) + %244 = arith.cmpi ne, %242, %cst_0 : tensor<32xi32> loc(#loc186) + %245 = arith.select %244, %243, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %246 = arith.xori %239, %245 : tensor<32xi32> loc(#loc187) + %247 = arith.xori %235, %236 : tensor<32xi32> loc(#loc188) + %248 = arith.select %244, %247, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %249 = arith.xori %213, %248 : 
tensor<32xi32> loc(#loc190) + %250 = tt.bitcast %246 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %251 = tt.broadcast %15 : tensor<1x2x1xi32> -> tensor<1x2x16xi32> loc(#loc104) + %252 = tt.reshape %251 : tensor<1x2x16xi32> -> tensor<32xi32> loc(#loc105) + %253 = tt.reshape %250 : tensor<32xf32> -> tensor<2x2x8xf32> loc(#loc149) + %254 = tt.reshape %249 : tensor<32xi32> -> tensor<2x2x8xi32> loc(#loc150) + %255 = tt.broadcast %21 : tensor<1x2x1xf32> -> tensor<2x2x8xf32> loc(#loc152) + %256 = arith.mulf %253, %255 : tensor<2x2x8xf32> loc(#loc152) + %257 = "tt.reduce"(%256) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<2x2x8xf32>) -> tensor<2x8xf32> loc(#loc192) + %258 = tt.expand_dims %257 {axis = 1 : i32} : tensor<2x8xf32> -> tensor<2x1x8xf32> loc(#loc156) + %259 = tt.broadcast %258 : tensor<2x1x8xf32> -> tensor<2x2x8xf32> loc(#loc157) + %260 = tt.broadcast %27 : tensor<1x2x1xf32> -> tensor<2x2x8xf32> loc(#loc158) + %261 = arith.mulf %253, %260 : tensor<2x2x8xf32> loc(#loc158) + %262 = "tt.reduce"(%261) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<2x2x8xf32>) -> tensor<2x8xf32> loc(#loc195) + %263 = tt.expand_dims %262 {axis = 1 : i32} : tensor<2x8xf32> -> tensor<2x1x8xf32> loc(#loc162) + %264 = tt.broadcast %263 : tensor<2x1x8xf32> -> tensor<2x2x8xf32> loc(#loc163) + %265 = tt.reshape %259 : tensor<2x2x8xf32> -> tensor<32xf32> loc(#loc164) + %266 = tt.reshape %264 : tensor<2x2x8xf32> -> tensor<32xf32> loc(#loc165) + %267 = tt.broadcast %20 : tensor<1x2x1xi32> -> tensor<2x2x8xi32> loc(#loc166) + %268 = arith.muli %254, %267 : tensor<2x2x8xi32> loc(#loc166) + %269 = "tt.reduce"(%268) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<2x2x8xi32>) -> tensor<2x8xi32> loc(#loc198) + %270 = tt.expand_dims %269 {axis = 1 : i32} : tensor<2x8xi32> -> tensor<2x1x8xi32> loc(#loc170) + %271 = tt.broadcast %270 : tensor<2x1x8xi32> -> tensor<2x2x8xi32> loc(#loc171) + %272 = arith.muli %254, %138 : tensor<2x2x8xi32> loc(#loc172) + %273 = "tt.reduce"(%272) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<2x2x8xi32>) -> tensor<2x8xi32> loc(#loc201) + %274 = tt.expand_dims %273 {axis = 1 : i32} : tensor<2x8xi32> -> tensor<2x1x8xi32> loc(#loc176) + %275 = tt.broadcast %274 : tensor<2x1x8xi32> -> tensor<2x2x8xi32> loc(#loc177) + %276 = tt.reshape %271 : tensor<2x2x8xi32> -> tensor<32xi32> loc(#loc178) + %277 = tt.reshape %275 : tensor<2x2x8xi32> -> tensor<32xi32> loc(#loc179) + %278 = tt.bitcast %265 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %279 = tt.bitcast %266 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %280 = tt.bitcast %250 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %281 = arith.cmpf ogt, %265, %266 : tensor<32xf32> loc(#loc183) + %282 = arith.extui %281 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %283 = arith.xori %282, %252 : tensor<32xi32> loc(#loc184) + 
%284 = arith.xori %278, %279 : tensor<32xi32> loc(#loc185) + %285 = arith.cmpi ne, %283, %cst_0 : tensor<32xi32> loc(#loc186) + %286 = arith.select %285, %284, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %287 = arith.xori %280, %286 : tensor<32xi32> loc(#loc187) + %288 = arith.xori %276, %277 : tensor<32xi32> loc(#loc188) + %289 = arith.select %285, %288, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %290 = arith.xori %249, %289 : tensor<32xi32> loc(#loc190) + %291 = tt.bitcast %287 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %292 = tt.reshape %291 : tensor<32xf32> -> tensor<4x2x4xf32> loc(#loc149) + %293 = tt.reshape %290 : tensor<32xi32> -> tensor<4x2x4xi32> loc(#loc150) + %294 = arith.mulf %292, %142 : tensor<4x2x4xf32> loc(#loc152) + %295 = "tt.reduce"(%294) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<4x2x4xf32>) -> tensor<4x4xf32> loc(#loc192) + %296 = tt.expand_dims %295 {axis = 1 : i32} : tensor<4x4xf32> -> tensor<4x1x4xf32> loc(#loc156) + %297 = tt.broadcast %296 : tensor<4x1x4xf32> -> tensor<4x2x4xf32> loc(#loc157) + %298 = arith.mulf %292, %147 : tensor<4x2x4xf32> loc(#loc158) + %299 = "tt.reduce"(%298) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<4x2x4xf32>) -> tensor<4x4xf32> loc(#loc195) + %300 = tt.expand_dims %299 {axis = 1 : i32} : tensor<4x4xf32> -> tensor<4x1x4xf32> loc(#loc162) + %301 = tt.broadcast %300 : tensor<4x1x4xf32> -> tensor<4x2x4xf32> loc(#loc163) + %302 = tt.reshape %297 : tensor<4x2x4xf32> -> tensor<32xf32> loc(#loc164) + %303 = tt.reshape %301 : tensor<4x2x4xf32> -> tensor<32xf32> loc(#loc165) + %304 = arith.muli %293, %154 : tensor<4x2x4xi32> loc(#loc166) + %305 = "tt.reduce"(%304) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<4x2x4xi32>) -> tensor<4x4xi32> loc(#loc198) + %306 = tt.expand_dims %305 {axis = 1 : i32} : tensor<4x4xi32> -> tensor<4x1x4xi32> loc(#loc170) + %307 = tt.broadcast %306 : tensor<4x1x4xi32> -> tensor<4x2x4xi32> loc(#loc171) + %308 = arith.muli %293, %61 : tensor<4x2x4xi32> loc(#loc172) + %309 = "tt.reduce"(%308) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<4x2x4xi32>) -> tensor<4x4xi32> loc(#loc201) + %310 = tt.expand_dims %309 {axis = 1 : i32} : tensor<4x4xi32> -> tensor<4x1x4xi32> loc(#loc176) + %311 = tt.broadcast %310 : tensor<4x1x4xi32> -> tensor<4x2x4xi32> loc(#loc177) + %312 = tt.reshape %307 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc178) + %313 = tt.reshape %311 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc179) + %314 = tt.bitcast %302 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %315 = tt.bitcast %303 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %316 = tt.bitcast %291 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %317 = arith.cmpf ogt, %302, %303 : tensor<32xf32> loc(#loc183) + %318 = arith.extui %317 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %319 = 
arith.xori %318, %252 : tensor<32xi32> loc(#loc184) + %320 = arith.xori %314, %315 : tensor<32xi32> loc(#loc185) + %321 = arith.cmpi ne, %319, %cst_0 : tensor<32xi32> loc(#loc186) + %322 = arith.select %321, %320, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %323 = arith.xori %316, %322 : tensor<32xi32> loc(#loc187) + %324 = arith.xori %312, %313 : tensor<32xi32> loc(#loc188) + %325 = arith.select %321, %324, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %326 = arith.xori %290, %325 : tensor<32xi32> loc(#loc190) + %327 = tt.bitcast %323 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %328 = tt.reshape %327 : tensor<32xf32> -> tensor<8x2x2xf32> loc(#loc149) + %329 = tt.reshape %326 : tensor<32xi32> -> tensor<8x2x2xi32> loc(#loc150) + %330 = arith.mulf %328, %65 : tensor<8x2x2xf32> loc(#loc152) + %331 = "tt.reduce"(%330) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc192) + %332 = tt.expand_dims %331 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc156) + %333 = tt.broadcast %332 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc157) + %334 = arith.mulf %328, %70 : tensor<8x2x2xf32> loc(#loc158) + %335 = "tt.reduce"(%334) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc195) + %336 = tt.expand_dims %335 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc162) + %337 = tt.broadcast %336 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc163) + %338 = tt.reshape %333 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc164) + %339 = tt.reshape %337 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc165) + %340 = arith.muli %329, %77 : tensor<8x2x2xi32> loc(#loc166) + %341 = "tt.reduce"(%340) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc198) + %342 = tt.expand_dims %341 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc170) + %343 = tt.broadcast %342 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc171) + %344 = arith.muli %329, %16 : tensor<8x2x2xi32> loc(#loc172) + %345 = "tt.reduce"(%344) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc201) + %346 = tt.expand_dims %345 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc176) + %347 = tt.broadcast %346 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc177) + %348 = tt.reshape %343 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc178) + %349 = tt.reshape %347 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc179) + %350 = tt.bitcast %338 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %351 = tt.bitcast %339 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %352 = tt.bitcast %327 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %353 = arith.cmpf ogt, %338, %339 : tensor<32xf32> loc(#loc183) + %354 = arith.extui %353 : 
tensor<32xi1> to tensor<32xi32> loc(#loc184) + %355 = arith.xori %354, %252 : tensor<32xi32> loc(#loc184) + %356 = arith.xori %350, %351 : tensor<32xi32> loc(#loc185) + %357 = arith.cmpi ne, %355, %cst_0 : tensor<32xi32> loc(#loc186) + %358 = arith.select %357, %356, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %359 = arith.xori %352, %358 : tensor<32xi32> loc(#loc187) + %360 = arith.xori %348, %349 : tensor<32xi32> loc(#loc188) + %361 = arith.select %357, %360, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %362 = arith.xori %326, %361 : tensor<32xi32> loc(#loc190) + %363 = tt.bitcast %359 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %364 = tt.reshape %363 : tensor<32xf32> -> tensor<16x2x1xf32> loc(#loc149) + %365 = tt.reshape %362 : tensor<32xi32> -> tensor<16x2x1xi32> loc(#loc150) + %366 = arith.mulf %364, %22 : tensor<16x2x1xf32> loc(#loc152) + %367 = "tt.reduce"(%366) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc192) + %368 = tt.expand_dims %367 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc156) + %369 = tt.broadcast %368 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc157) + %370 = arith.mulf %364, %28 : tensor<16x2x1xf32> loc(#loc158) + %371 = "tt.reduce"(%370) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc195) + %372 = tt.expand_dims %371 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc162) + %373 = tt.broadcast %372 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc163) + %374 = tt.reshape %369 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc164) + %375 = tt.reshape %373 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc165) + %376 = arith.muli %365, %35 : tensor<16x2x1xi32> loc(#loc166) + %377 = "tt.reduce"(%376) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc198) + %378 = tt.expand_dims %377 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc170) + %379 = tt.broadcast %378 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc171) + %380 = arith.muli %365, %40 : tensor<16x2x1xi32> loc(#loc172) + %381 = "tt.reduce"(%380) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc201) + %382 = tt.expand_dims %381 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc176) + %383 = tt.broadcast %382 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc177) + %384 = tt.reshape %379 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc178) + %385 = tt.reshape %383 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc179) + %386 = tt.bitcast %374 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %387 = tt.bitcast %375 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %388 = tt.bitcast %363 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %389 = 
arith.cmpf ogt, %374, %375 : tensor<32xf32> loc(#loc183) + %390 = arith.extui %389 : tensor<32xi1> to tensor<32xi32> loc(#loc184) + %391 = arith.xori %390, %252 : tensor<32xi32> loc(#loc184) + %392 = arith.xori %386, %387 : tensor<32xi32> loc(#loc185) + %393 = arith.cmpi ne, %391, %cst_0 : tensor<32xi32> loc(#loc186) + %394 = arith.select %393, %392, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %395 = arith.xori %388, %394 : tensor<32xi32> loc(#loc187) + %396 = arith.xori %384, %385 : tensor<32xi32> loc(#loc188) + %397 = arith.select %393, %396, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %398 = arith.xori %362, %397 : tensor<32xi32> loc(#loc190) + %399 = tt.bitcast %395 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %400 = tt.reshape %399 : tensor<32xf32> -> tensor<1x2x16xf32> loc(#loc149) + %401 = tt.reshape %398 : tensor<32xi32> -> tensor<1x2x16xi32> loc(#loc150) + %402 = tt.broadcast %21 : tensor<1x2x1xf32> -> tensor<1x2x16xf32> loc(#loc152) + %403 = arith.mulf %400, %402 : tensor<1x2x16xf32> loc(#loc152) + %404 = "tt.reduce"(%403) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<1x2x16xf32>) -> tensor<1x16xf32> loc(#loc192) + %405 = tt.expand_dims %404 {axis = 1 : i32} : tensor<1x16xf32> -> tensor<1x1x16xf32> loc(#loc156) + %406 = tt.broadcast %405 : tensor<1x1x16xf32> -> tensor<1x2x16xf32> loc(#loc157) + %407 = tt.broadcast %27 : tensor<1x2x1xf32> -> tensor<1x2x16xf32> loc(#loc158) + %408 = arith.mulf %400, %407 : tensor<1x2x16xf32> loc(#loc158) + %409 = "tt.reduce"(%408) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<1x2x16xf32>) -> tensor<1x16xf32> loc(#loc195) + %410 = tt.expand_dims %409 {axis = 1 : i32} : tensor<1x16xf32> -> tensor<1x1x16xf32> loc(#loc162) + %411 = tt.broadcast %410 : tensor<1x1x16xf32> -> tensor<1x2x16xf32> loc(#loc163) + %412 = tt.reshape %406 : tensor<1x2x16xf32> -> tensor<32xf32> loc(#loc164) + %413 = tt.reshape %411 : tensor<1x2x16xf32> -> tensor<32xf32> loc(#loc165) + %414 = tt.broadcast %20 : tensor<1x2x1xi32> -> tensor<1x2x16xi32> loc(#loc166) + %415 = arith.muli %401, %414 : tensor<1x2x16xi32> loc(#loc166) + %416 = "tt.reduce"(%415) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<1x2x16xi32>) -> tensor<1x16xi32> loc(#loc198) + %417 = tt.expand_dims %416 {axis = 1 : i32} : tensor<1x16xi32> -> tensor<1x1x16xi32> loc(#loc170) + %418 = tt.broadcast %417 : tensor<1x1x16xi32> -> tensor<1x2x16xi32> loc(#loc171) + %419 = arith.muli %401, %251 : tensor<1x2x16xi32> loc(#loc172) + %420 = "tt.reduce"(%419) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<1x2x16xi32>) -> tensor<1x16xi32> loc(#loc201) + %421 = tt.expand_dims %420 {axis = 1 : i32} : tensor<1x16xi32> -> tensor<1x1x16xi32> loc(#loc176) + %422 = tt.broadcast %421 : tensor<1x1x16xi32> -> tensor<1x2x16xi32> loc(#loc177) + %423 = tt.reshape %418 : tensor<1x2x16xi32> -> 
tensor<32xi32> loc(#loc178) + %424 = tt.reshape %422 : tensor<1x2x16xi32> -> tensor<32xi32> loc(#loc179) + %425 = tt.bitcast %412 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %426 = tt.bitcast %413 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %427 = tt.bitcast %399 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %428 = arith.cmpf ogt, %412, %413 : tensor<32xf32> loc(#loc183) + %429 = arith.xori %425, %426 : tensor<32xi32> loc(#loc185) + %430 = arith.select %428, %429, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %431 = arith.xori %427, %430 : tensor<32xi32> loc(#loc187) + %432 = arith.xori %423, %424 : tensor<32xi32> loc(#loc188) + %433 = arith.select %428, %432, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %434 = arith.xori %398, %433 : tensor<32xi32> loc(#loc190) + %435 = tt.bitcast %431 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %436 = tt.reshape %435 : tensor<32xf32> -> tensor<2x2x8xf32> loc(#loc149) + %437 = tt.reshape %434 : tensor<32xi32> -> tensor<2x2x8xi32> loc(#loc150) + %438 = arith.mulf %436, %255 : tensor<2x2x8xf32> loc(#loc152) + %439 = "tt.reduce"(%438) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<2x2x8xf32>) -> tensor<2x8xf32> loc(#loc192) + %440 = tt.expand_dims %439 {axis = 1 : i32} : tensor<2x8xf32> -> tensor<2x1x8xf32> loc(#loc156) + %441 = tt.broadcast %440 : tensor<2x1x8xf32> -> tensor<2x2x8xf32> loc(#loc157) + %442 = arith.mulf %436, %260 : tensor<2x2x8xf32> loc(#loc158) + %443 = "tt.reduce"(%442) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<2x2x8xf32>) -> tensor<2x8xf32> loc(#loc195) + %444 = tt.expand_dims %443 {axis = 1 : i32} : tensor<2x8xf32> -> tensor<2x1x8xf32> loc(#loc162) + %445 = tt.broadcast %444 : tensor<2x1x8xf32> -> tensor<2x2x8xf32> loc(#loc163) + %446 = tt.reshape %441 : tensor<2x2x8xf32> -> tensor<32xf32> loc(#loc164) + %447 = tt.reshape %445 : tensor<2x2x8xf32> -> tensor<32xf32> loc(#loc165) + %448 = arith.muli %437, %267 : tensor<2x2x8xi32> loc(#loc166) + %449 = "tt.reduce"(%448) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<2x2x8xi32>) -> tensor<2x8xi32> loc(#loc198) + %450 = tt.expand_dims %449 {axis = 1 : i32} : tensor<2x8xi32> -> tensor<2x1x8xi32> loc(#loc170) + %451 = tt.broadcast %450 : tensor<2x1x8xi32> -> tensor<2x2x8xi32> loc(#loc171) + %452 = arith.muli %437, %138 : tensor<2x2x8xi32> loc(#loc172) + %453 = "tt.reduce"(%452) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<2x2x8xi32>) -> tensor<2x8xi32> loc(#loc201) + %454 = tt.expand_dims %453 {axis = 1 : i32} : tensor<2x8xi32> -> tensor<2x1x8xi32> loc(#loc176) + %455 = tt.broadcast %454 : tensor<2x1x8xi32> -> tensor<2x2x8xi32> loc(#loc177) + %456 = tt.reshape %451 : tensor<2x2x8xi32> -> tensor<32xi32> loc(#loc178) + %457 = tt.reshape %455 : tensor<2x2x8xi32> -> tensor<32xi32> loc(#loc179) + %458 = tt.bitcast %446 : tensor<32xf32> 
-> tensor<32xi32> loc(#loc180) + %459 = tt.bitcast %447 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %460 = tt.bitcast %435 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %461 = arith.cmpf ogt, %446, %447 : tensor<32xf32> loc(#loc183) + %462 = arith.xori %458, %459 : tensor<32xi32> loc(#loc185) + %463 = arith.select %461, %462, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %464 = arith.xori %460, %463 : tensor<32xi32> loc(#loc187) + %465 = arith.xori %456, %457 : tensor<32xi32> loc(#loc188) + %466 = arith.select %461, %465, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %467 = arith.xori %434, %466 : tensor<32xi32> loc(#loc190) + %468 = tt.bitcast %464 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %469 = tt.reshape %468 : tensor<32xf32> -> tensor<4x2x4xf32> loc(#loc149) + %470 = tt.reshape %467 : tensor<32xi32> -> tensor<4x2x4xi32> loc(#loc150) + %471 = arith.mulf %469, %142 : tensor<4x2x4xf32> loc(#loc152) + %472 = "tt.reduce"(%471) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<4x2x4xf32>) -> tensor<4x4xf32> loc(#loc192) + %473 = tt.expand_dims %472 {axis = 1 : i32} : tensor<4x4xf32> -> tensor<4x1x4xf32> loc(#loc156) + %474 = tt.broadcast %473 : tensor<4x1x4xf32> -> tensor<4x2x4xf32> loc(#loc157) + %475 = arith.mulf %469, %147 : tensor<4x2x4xf32> loc(#loc158) + %476 = "tt.reduce"(%475) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<4x2x4xf32>) -> tensor<4x4xf32> loc(#loc195) + %477 = tt.expand_dims %476 {axis = 1 : i32} : tensor<4x4xf32> -> tensor<4x1x4xf32> loc(#loc162) + %478 = tt.broadcast %477 : tensor<4x1x4xf32> -> tensor<4x2x4xf32> loc(#loc163) + %479 = tt.reshape %474 : tensor<4x2x4xf32> -> tensor<32xf32> loc(#loc164) + %480 = tt.reshape %478 : tensor<4x2x4xf32> -> tensor<32xf32> loc(#loc165) + %481 = arith.muli %470, %154 : tensor<4x2x4xi32> loc(#loc166) + %482 = "tt.reduce"(%481) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<4x2x4xi32>) -> tensor<4x4xi32> loc(#loc198) + %483 = tt.expand_dims %482 {axis = 1 : i32} : tensor<4x4xi32> -> tensor<4x1x4xi32> loc(#loc170) + %484 = tt.broadcast %483 : tensor<4x1x4xi32> -> tensor<4x2x4xi32> loc(#loc171) + %485 = arith.muli %470, %61 : tensor<4x2x4xi32> loc(#loc172) + %486 = "tt.reduce"(%485) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<4x2x4xi32>) -> tensor<4x4xi32> loc(#loc201) + %487 = tt.expand_dims %486 {axis = 1 : i32} : tensor<4x4xi32> -> tensor<4x1x4xi32> loc(#loc176) + %488 = tt.broadcast %487 : tensor<4x1x4xi32> -> tensor<4x2x4xi32> loc(#loc177) + %489 = tt.reshape %484 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc178) + %490 = tt.reshape %488 : tensor<4x2x4xi32> -> tensor<32xi32> loc(#loc179) + %491 = tt.bitcast %479 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %492 = tt.bitcast %480 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %493 = tt.bitcast %468 : tensor<32xf32> -> 
tensor<32xi32> loc(#loc182) + %494 = arith.cmpf ogt, %479, %480 : tensor<32xf32> loc(#loc183) + %495 = arith.xori %491, %492 : tensor<32xi32> loc(#loc185) + %496 = arith.select %494, %495, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %497 = arith.xori %493, %496 : tensor<32xi32> loc(#loc187) + %498 = arith.xori %489, %490 : tensor<32xi32> loc(#loc188) + %499 = arith.select %494, %498, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %500 = arith.xori %467, %499 : tensor<32xi32> loc(#loc190) + %501 = tt.bitcast %497 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %502 = tt.reshape %501 : tensor<32xf32> -> tensor<8x2x2xf32> loc(#loc149) + %503 = tt.reshape %500 : tensor<32xi32> -> tensor<8x2x2xi32> loc(#loc150) + %504 = arith.mulf %502, %65 : tensor<8x2x2xf32> loc(#loc152) + %505 = "tt.reduce"(%504) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc192) + %506 = tt.expand_dims %505 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc156) + %507 = tt.broadcast %506 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc157) + %508 = arith.mulf %502, %70 : tensor<8x2x2xf32> loc(#loc158) + %509 = "tt.reduce"(%508) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<8x2x2xf32>) -> tensor<8x2xf32> loc(#loc195) + %510 = tt.expand_dims %509 {axis = 1 : i32} : tensor<8x2xf32> -> tensor<8x1x2xf32> loc(#loc162) + %511 = tt.broadcast %510 : tensor<8x1x2xf32> -> tensor<8x2x2xf32> loc(#loc163) + %512 = tt.reshape %507 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc164) + %513 = tt.reshape %511 : tensor<8x2x2xf32> -> tensor<32xf32> loc(#loc165) + %514 = arith.muli %503, %77 : tensor<8x2x2xi32> loc(#loc166) + %515 = "tt.reduce"(%514) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc198) + %516 = tt.expand_dims %515 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc170) + %517 = tt.broadcast %516 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc171) + %518 = arith.muli %503, %16 : tensor<8x2x2xi32> loc(#loc172) + %519 = "tt.reduce"(%518) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<8x2x2xi32>) -> tensor<8x2xi32> loc(#loc201) + %520 = tt.expand_dims %519 {axis = 1 : i32} : tensor<8x2xi32> -> tensor<8x1x2xi32> loc(#loc176) + %521 = tt.broadcast %520 : tensor<8x1x2xi32> -> tensor<8x2x2xi32> loc(#loc177) + %522 = tt.reshape %517 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc178) + %523 = tt.reshape %521 : tensor<8x2x2xi32> -> tensor<32xi32> loc(#loc179) + %524 = tt.bitcast %512 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %525 = tt.bitcast %513 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %526 = tt.bitcast %501 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %527 = arith.cmpf ogt, %512, %513 : tensor<32xf32> loc(#loc183) + %528 = arith.xori %524, %525 : tensor<32xi32> 
loc(#loc185) + %529 = arith.select %527, %528, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %530 = arith.xori %526, %529 : tensor<32xi32> loc(#loc187) + %531 = arith.xori %522, %523 : tensor<32xi32> loc(#loc188) + %532 = arith.select %527, %531, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %533 = arith.xori %500, %532 : tensor<32xi32> loc(#loc190) + %534 = tt.bitcast %530 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + %535 = tt.reshape %534 : tensor<32xf32> -> tensor<16x2x1xf32> loc(#loc149) + %536 = tt.reshape %533 : tensor<32xi32> -> tensor<16x2x1xi32> loc(#loc150) + %537 = arith.mulf %535, %22 : tensor<16x2x1xf32> loc(#loc152) + %538 = "tt.reduce"(%537) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc154 at #loc13)), %arg4: f32 loc(callsite(#loc154 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc204) + tt.reduce.return %569 : f32 loc(#loc192) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc192) + %539 = tt.expand_dims %538 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc156) + %540 = tt.broadcast %539 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc157) + %541 = arith.mulf %535, %28 : tensor<16x2x1xf32> loc(#loc158) + %542 = "tt.reduce"(%541) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc160 at #loc13)), %arg4: f32 loc(callsite(#loc160 at #loc13))): + %569 = arith.addf %arg3, %arg4 : f32 loc(#loc205) + tt.reduce.return %569 : f32 loc(#loc195) + }) : (tensor<16x2x1xf32>) -> tensor<16x1xf32> loc(#loc195) + %543 = tt.expand_dims %542 {axis = 1 : i32} : tensor<16x1xf32> -> tensor<16x1x1xf32> loc(#loc162) + %544 = tt.broadcast %543 : tensor<16x1x1xf32> -> tensor<16x2x1xf32> loc(#loc163) + %545 = tt.reshape %540 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc164) + %546 = tt.reshape %544 : tensor<16x2x1xf32> -> tensor<32xf32> loc(#loc165) + %547 = arith.muli %536, %35 : tensor<16x2x1xi32> loc(#loc166) + %548 = "tt.reduce"(%547) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc168 at #loc13)), %arg4: i32 loc(callsite(#loc168 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc206) + tt.reduce.return %569 : i32 loc(#loc198) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc198) + %549 = tt.expand_dims %548 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc170) + %550 = tt.broadcast %549 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc171) + %551 = arith.muli %536, %40 : tensor<16x2x1xi32> loc(#loc172) + %552 = "tt.reduce"(%551) <{axis = 1 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc174 at #loc13)), %arg4: i32 loc(callsite(#loc174 at #loc13))): + %569 = arith.addi %arg3, %arg4 : i32 loc(#loc207) + tt.reduce.return %569 : i32 loc(#loc201) + }) : (tensor<16x2x1xi32>) -> tensor<16x1xi32> loc(#loc201) + %553 = tt.expand_dims %552 {axis = 1 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> loc(#loc176) + %554 = tt.broadcast %553 : tensor<16x1x1xi32> -> tensor<16x2x1xi32> loc(#loc177) + %555 = tt.reshape %550 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc178) + %556 = tt.reshape %554 : tensor<16x2x1xi32> -> tensor<32xi32> loc(#loc179) + %557 = tt.bitcast %545 : tensor<32xf32> -> tensor<32xi32> loc(#loc180) + %558 = tt.bitcast %546 : tensor<32xf32> -> tensor<32xi32> loc(#loc181) + %559 = tt.bitcast %534 : tensor<32xf32> -> tensor<32xi32> loc(#loc182) + %560 = arith.cmpf ogt, %545, %546 : tensor<32xf32> loc(#loc183) + %561 = arith.xori %557, %558 : tensor<32xi32> loc(#loc185) + %562 = arith.select %560, %561, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc186) + %563 = arith.xori 
%559, %562 : tensor<32xi32> loc(#loc187) + %564 = arith.xori %555, %556 : tensor<32xi32> loc(#loc188) + %565 = arith.select %560, %564, %cst_0 : tensor<32xi1>, tensor<32xi32> loc(#loc189) + %566 = arith.xori %533, %565 : tensor<32xi32> loc(#loc190) + %567 = tt.bitcast %563 : tensor<32xi32> -> tensor<32xf32> loc(#loc191) + tt.store %9, %567, %1 : tensor<32x!tt.ptr> loc(#loc55) + %568 = arith.extsi %566 : tensor<32xi32> to tensor<32xi64> loc(#loc56) + tt.store %11, %568, %1 : tensor<32x!tt.ptr> loc(#loc56) + tt.return loc(#loc57) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":181:24) +#loc3 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":182:18) +#loc4 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":183:27) +#loc5 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":183:32) +#loc6 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":183:36) +#loc7 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":184:14) +#loc8 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":185:15) +#loc9 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":186:21) +#loc10 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":190:25) +#loc11 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":151:45) +#loc14 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":151:48) +#loc15 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":151:64) +#loc16 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":151:72) +#loc17 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":75:24) +#loc19 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":76:30) +#loc20 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":80:45) +#loc21 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":80:41) +#loc22 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":267:36) +#loc24 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":256:15) +#loc25 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":80:55) +#loc26 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":80:68) +#loc27 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":81:41) +#loc29 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":81:50) +#loc30 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":81:63) +#loc31 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":82:30) +#loc32 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":83:32) +#loc33 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":85:49) +#loc35 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":85:63) +#loc36 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":85:76) +#loc37 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":88:49) +#loc39 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":88:58) +#loc40 = 
loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":88:71) +#loc41 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":91:38) +#loc42 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":92:40) +#loc43 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":106:20) +#loc44 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":107:22) +#loc45 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":108:14) +#loc46 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":110:19) +#loc47 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":110:28) +#loc48 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":111:40) +#loc49 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":111:48) +#loc50 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":111:15) +#loc51 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":127:52) +#loc52 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":127:64) +#loc53 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":127:23) +#loc54 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":129:18) +#loc55 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":200:22) +#loc56 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":201:28) +#loc57 = loc("/home/zhengyang/FlagTree/flagtree/python/test/ops/sort/sort_ascend.py":201:4) +#loc58 = loc(callsite(#loc11 at #loc12)) +#loc59 = loc(callsite(#loc14 at #loc12)) +#loc60 = loc(callsite(#loc15 at #loc12)) +#loc61 = loc(callsite(#loc16 at #loc12)) +#loc62 = loc(callsite(#loc17 at #loc18)) +#loc63 = loc(callsite(#loc19 at #loc18)) +#loc64 = loc(callsite(#loc20 at #loc18)) +#loc65 = loc(callsite(#loc21 at #loc18)) +#loc66 = loc(callsite(#loc22 at #loc23)) +#loc68 = loc(callsite(#loc24 at #loc22)) +#loc69 = loc(callsite(#loc25 at #loc18)) +#loc70 = loc(callsite(#loc26 at #loc18)) +#loc71 = loc(callsite(#loc27 at #loc18)) +#loc72 = loc(callsite(#loc22 at #loc28)) +#loc74 = loc(callsite(#loc29 at #loc18)) +#loc75 = loc(callsite(#loc30 at #loc18)) +#loc76 = loc(callsite(#loc31 at #loc18)) +#loc77 = loc(callsite(#loc32 at #loc18)) +#loc78 = loc(callsite(#loc33 at #loc18)) +#loc79 = loc(callsite(#loc22 at #loc34)) +#loc81 = loc(callsite(#loc35 at #loc18)) +#loc82 = loc(callsite(#loc36 at #loc18)) +#loc83 = loc(callsite(#loc37 at #loc18)) +#loc84 = loc(callsite(#loc22 at #loc38)) +#loc86 = loc(callsite(#loc39 at #loc18)) +#loc87 = loc(callsite(#loc40 at #loc18)) +#loc88 = loc(callsite(#loc41 at #loc18)) +#loc89 = loc(callsite(#loc42 at #loc18)) +#loc90 = loc(callsite(#loc43 at #loc18)) +#loc91 = loc(callsite(#loc44 at #loc18)) +#loc92 = loc(callsite(#loc45 at #loc18)) +#loc93 = loc(callsite(#loc46 at #loc18)) +#loc94 = loc(callsite(#loc47 at #loc18)) +#loc95 = loc(callsite(#loc48 at #loc18)) +#loc96 = loc(callsite(#loc49 at #loc18)) +#loc97 = loc(callsite(#loc50 at #loc18)) +#loc98 = loc(callsite(#loc51 at #loc18)) +#loc99 = loc(callsite(#loc52 at #loc18)) +#loc100 = loc(callsite(#loc53 at #loc18)) +#loc101 = loc(callsite(#loc54 at #loc18)) +#loc102 = loc(callsite(#loc58 at #loc13)) +#loc103 = loc(callsite(#loc59 at #loc13)) +#loc104 = loc(callsite(#loc60 at #loc13)) +#loc105 = loc(callsite(#loc61 at #loc13)) +#loc106 = 
loc(callsite(#loc62 at #loc12)) +#loc107 = loc(callsite(#loc63 at #loc12)) +#loc108 = loc(callsite(#loc64 at #loc12)) +#loc109 = loc(callsite(#loc65 at #loc12)) +#loc110 = loc(callsite(#loc66 at #loc18)) +#loc112 = loc(callsite(#loc68 at #loc23)) +#loc113 = loc(callsite(#loc69 at #loc12)) +#loc114 = loc(callsite(#loc70 at #loc12)) +#loc115 = loc(callsite(#loc71 at #loc12)) +#loc116 = loc(callsite(#loc72 at #loc18)) +#loc118 = loc(callsite(#loc68 at #loc28)) +#loc119 = loc(callsite(#loc74 at #loc12)) +#loc120 = loc(callsite(#loc75 at #loc12)) +#loc121 = loc(callsite(#loc76 at #loc12)) +#loc122 = loc(callsite(#loc77 at #loc12)) +#loc123 = loc(callsite(#loc78 at #loc12)) +#loc124 = loc(callsite(#loc79 at #loc18)) +#loc126 = loc(callsite(#loc68 at #loc34)) +#loc127 = loc(callsite(#loc81 at #loc12)) +#loc128 = loc(callsite(#loc82 at #loc12)) +#loc129 = loc(callsite(#loc83 at #loc12)) +#loc130 = loc(callsite(#loc84 at #loc18)) +#loc132 = loc(callsite(#loc68 at #loc38)) +#loc133 = loc(callsite(#loc86 at #loc12)) +#loc134 = loc(callsite(#loc87 at #loc12)) +#loc135 = loc(callsite(#loc88 at #loc12)) +#loc136 = loc(callsite(#loc89 at #loc12)) +#loc137 = loc(callsite(#loc90 at #loc12)) +#loc138 = loc(callsite(#loc91 at #loc12)) +#loc139 = loc(callsite(#loc92 at #loc12)) +#loc140 = loc(callsite(#loc93 at #loc12)) +#loc141 = loc(callsite(#loc94 at #loc12)) +#loc142 = loc(callsite(#loc95 at #loc12)) +#loc143 = loc(callsite(#loc96 at #loc12)) +#loc144 = loc(callsite(#loc97 at #loc12)) +#loc145 = loc(callsite(#loc98 at #loc12)) +#loc146 = loc(callsite(#loc99 at #loc12)) +#loc147 = loc(callsite(#loc100 at #loc12)) +#loc148 = loc(callsite(#loc101 at #loc12)) +#loc149 = loc(callsite(#loc106 at #loc13)) +#loc150 = loc(callsite(#loc107 at #loc13)) +#loc151 = loc(callsite(#loc108 at #loc13)) +#loc152 = loc(callsite(#loc109 at #loc13)) +#loc153 = loc(callsite(#loc110 at #loc12)) +#loc155 = loc(callsite(#loc112 at #loc18)) +#loc156 = loc(callsite(#loc113 at #loc13)) +#loc157 = loc(callsite(#loc114 at #loc13)) +#loc158 = loc(callsite(#loc115 at #loc13)) +#loc159 = loc(callsite(#loc116 at #loc12)) +#loc161 = loc(callsite(#loc118 at #loc18)) +#loc162 = loc(callsite(#loc119 at #loc13)) +#loc163 = loc(callsite(#loc120 at #loc13)) +#loc164 = loc(callsite(#loc121 at #loc13)) +#loc165 = loc(callsite(#loc122 at #loc13)) +#loc166 = loc(callsite(#loc123 at #loc13)) +#loc167 = loc(callsite(#loc124 at #loc12)) +#loc169 = loc(callsite(#loc126 at #loc18)) +#loc170 = loc(callsite(#loc127 at #loc13)) +#loc171 = loc(callsite(#loc128 at #loc13)) +#loc172 = loc(callsite(#loc129 at #loc13)) +#loc173 = loc(callsite(#loc130 at #loc12)) +#loc175 = loc(callsite(#loc132 at #loc18)) +#loc176 = loc(callsite(#loc133 at #loc13)) +#loc177 = loc(callsite(#loc134 at #loc13)) +#loc178 = loc(callsite(#loc135 at #loc13)) +#loc179 = loc(callsite(#loc136 at #loc13)) +#loc180 = loc(callsite(#loc137 at #loc13)) +#loc181 = loc(callsite(#loc138 at #loc13)) +#loc182 = loc(callsite(#loc139 at #loc13)) +#loc183 = loc(callsite(#loc140 at #loc13)) +#loc184 = loc(callsite(#loc141 at #loc13)) +#loc185 = loc(callsite(#loc142 at #loc13)) +#loc186 = loc(callsite(#loc143 at #loc13)) +#loc187 = loc(callsite(#loc144 at #loc13)) +#loc188 = loc(callsite(#loc145 at #loc13)) +#loc189 = loc(callsite(#loc146 at #loc13)) +#loc190 = loc(callsite(#loc147 at #loc13)) +#loc191 = loc(callsite(#loc148 at #loc13)) +#loc192 = loc(callsite(#loc153 at #loc13)) +#loc194 = loc(callsite(#loc155 at #loc12)) +#loc195 = loc(callsite(#loc159 at #loc13)) +#loc197 = loc(callsite(#loc161 at 
#loc12)) +#loc198 = loc(callsite(#loc167 at #loc13)) +#loc200 = loc(callsite(#loc169 at #loc12)) +#loc201 = loc(callsite(#loc173 at #loc13)) +#loc203 = loc(callsite(#loc175 at #loc12)) +#loc204 = loc(callsite(#loc194 at #loc13)) +#loc205 = loc(callsite(#loc197 at #loc13)) +#loc206 = loc(callsite(#loc200 at #loc13)) +#loc207 = loc(callsite(#loc203 at #loc13)) diff --git a/python/test/ops/sum_dim/sum_dim.py b/python/test/ops/sum_dim/sum_dim.py new file mode 100644 index 000000000..a19c4dfec --- /dev/null +++ b/python/test/ops/sum_dim/sum_dim.py @@ -0,0 +1,106 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def sum_dim_kernel( + inp, + out, + M, + N, + BLOCK_M: tl.constexpr = 8, + BLOCK_N: tl.constexpr = 256, +): + if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr( + inp.dtype.element_ty == tl.bfloat16 + ): + cdtype = tl.float32 + else: + cdtype = inp.dtype.element_ty + + # 1. prepare offset + pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + inp = inp + pid * N + out = out + pid + row_mask = pid < M + _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype) + + # 2. for + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N)[None, :] + col_mask = cols < N + mask = row_mask and col_mask + + a = tl.load(inp + cols, mask, other=0).to(cdtype) + _sum += a + + # 3. store + sum = tl.sum(_sum, axis=1)[:, None] + tl.store(out, sum, row_mask) + + +def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None): + if dtype is None: + dtype = inp.dtype + if dtype is torch.bool: + dtype = torch.int64 + + if dim == []: + pass + + shape = list(inp.shape) + dim = [d % inp.ndim for d in dim] + + if len(dim) == 1 or len(dim) > 1: + N = 1 + for i in dim: + N *= shape[i] + shape[i] = 1 + M = inp.numel() // N + if out is None: + out = torch.empty(shape, dtype=dtype, device=inp.device) + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + with torch_device_fn.device(inp.device): + sum_dim_kernel[grid](inp, out, M, N) + if not keepdim: + out = out.squeeze(dim=dim) + return out + + +def sum_dim(inp, dim=None, keepdim=False, *, dtype=None): + return sum_dim_comm(inp, dim, keepdim, dtype=dtype) + + +if __name__ == "__main__": + # param + shape = (1, 32) + dim = [1] + keepdim = True + + # inp + inp = torch.randn(shape, dtype=torch.float32, device=device) + ref_inp = inp.cpu() + + # op + ref_out = torch.sum(ref_inp, dim=dim, keepdim=keepdim) + res_out = sum_dim(inp, dim=dim, keepdim=keepdim) + + # check + res_out = res_out.cpu() + print( + f"The maximum difference out value between torch and triton is " + f"{torch.max(torch.abs(ref_out - res_out))}" + ) + assert torch.allclose(res_out, ref_out), (res_out, ref_out) diff --git a/python/test/ops/sum_dim/triton-ascend/sum_dim_kernel.ttadapter b/python/test/ops/sum_dim/triton-ascend/sum_dim_kernel.ttadapter new file mode 100644 index 000000000..0381344a0 --- /dev/null +++ b/python/test/ops/sum_dim/triton-ascend/sum_dim_kernel.ttadapter @@ -0,0 +1,63 @@ +module { + func.func @sum_dim_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, 
tt.tensor_kind = 1 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c256 = arith.constant 256 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x256xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x256xf32>) -> tensor<8x256xf32> + %2 = arith.muli %arg8, %c8_i32 : i32 + %3 = arith.index_cast %2 : i32 to index + %4 = arith.index_cast %arg4 : i32 to index + %5 = arith.muli %3, %4 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%3], sizes: [8, 1], strides: [1, 1] : memref to memref<8x1xf32, strided<[1, 1], offset: ?>> + %6 = scf.for %arg11 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg12 = %1) -> (tensor<8x256xf32>) : i32 { + %13 = arith.index_cast %arg11 : i32 to index + %14 = arith.addi %5, %13 : index + %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [%14], sizes: [8, 256], strides: [%4, 1] : memref to memref<8x256xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<8x256xf32> + %15 = arith.addi %3, %c8 : index + %16 = arith.maxsi %3, %c1 : index + %17 = arith.minsi %15, %16 : index + %18 = arith.subi %17, %3 : index + %19 = arith.addi %13, %c256 : index + %20 = arith.maxsi %13, %4 : index + %21 = arith.minsi %19, %20 : index + %22 = arith.subi %21, %13 : index + %23 = arith.minsi %18, %c8 : index + %24 = arith.minsi %22, %c256 : index + %25 = arith.cmpi slt, %23, %c8 : index + %26 = arith.cmpi slt, %24, %c256 : index + %27 = arith.ori %25, %26 : i1 + scf.if %27 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<8x256xf32>) + } + %subview_1 = memref.subview %reinterpret_cast_0[0, 0] [%23, %24] [1, 1] : memref<8x256xf32, strided<[?, 1], offset: ?>> to memref> + %subview_2 = memref.subview %alloc[0, 0] [%23, %24] [1, 1] : memref<8x256xf32> to memref> + memref.copy %subview_1, %subview_2 : memref> to memref> + %28 = bufferization.to_tensor %alloc restrict writable : memref<8x256xf32> + %29 = arith.addf %arg12, %28 : tensor<8x256xf32> + scf.yield %29 : tensor<8x256xf32> + } + %7 = tensor.empty() : tensor<8xf32> + %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8xf32>) -> tensor<8xf32> + %reduced = linalg.reduce ins(%6 : tensor<8x256xf32>) outs(%8 : tensor<8xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %13 = arith.addf %in, %init : f32 + linalg.yield %13 : f32 + } + %expanded = tensor.expand_shape %reduced [[0, 1]] output_shape [8, 1] : tensor<8xf32> into tensor<8x1xf32> + %9 = arith.addi %3, %c8 : index + %10 = arith.maxsi %3, %c1 : index + %11 = arith.minsi %9, %10 : index + %12 = arith.subi %11, %3 : index + %extracted_slice = tensor.extract_slice %expanded[0, 0] [%12, 1] [1, 1] : tensor<8x1xf32> to tensor + %subview = memref.subview %reinterpret_cast[0, 0] [%12, 1] [1, 1] : memref<8x1xf32, strided<[1, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/sum_dim/triton-ascend/sum_dim_kernel.ttir b/python/test/ops/sum_dim/triton-ascend/sum_dim_kernel.ttir new file mode 100644 index 000000000..317719d9a --- /dev/null +++ b/python/test/ops/sum_dim/triton-ascend/sum_dim_kernel.ttir @@ -0,0 +1,78 @@ 
+#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":17:0) +#loc1 = loc(unknown) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":49:17) +#loc28 = loc(callsite(#loc1 at #loc22)) +module { + tt.func public @sum_dim_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":17:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":17:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":17:0)) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<8x256xf32> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_0 = arith.constant dense<1> : tensor<8x1xi32> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c8_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> loc(#loc4) + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> loc(#loc5) + %4 = tt.splat %1 : i32 -> tensor<8x1xi32> loc(#loc6) + %5 = arith.addi %4, %3 : tensor<8x1xi32> loc(#loc6) + %6 = tt.splat %arg2 : i32 -> tensor<8x1xi32> loc(#loc7) + %7 = arith.muli %5, %6 : tensor<8x1xi32> loc(#loc7) + %8 = tt.splat %arg0 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc8) + %9 = tt.addptr %8, %7 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc8) + %10 = tt.splat %arg1 : !tt.ptr -> tensor<8x1x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %5 : tensor<8x1x!tt.ptr>, tensor<8x1xi32> loc(#loc9) + %12 = arith.cmpi slt, %5, %cst_0 : tensor<8x1xi32> loc(#loc10) + %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc11) + %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> loc(#loc12) + %15 = tt.splat %arg2 : i32 -> tensor<1x256xi32> loc(#loc13) + %16 = tt.broadcast %12 : tensor<8x1xi1> -> tensor<8x256xi1> loc(#loc14) + %17 = tt.broadcast %9 : tensor<8x1x!tt.ptr> -> tensor<8x256x!tt.ptr> loc(#loc15) + %18 = scf.for %arg3 = %c0_i32 to %arg2 step %c256_i32 iter_args(%arg4 = %cst) -> (tensor<8x256xf32>) : i32 { + %21 = tt.splat %arg3 : i32 -> tensor<1x256xi32> loc(#loc17) + %22 = arith.addi %21, %14 : tensor<1x256xi32> loc(#loc17) + %23 = arith.cmpi slt, %22, %15 : tensor<1x256xi32> loc(#loc13) + %24 = tt.broadcast %23 : tensor<1x256xi1> -> tensor<8x256xi1> loc(#loc14) + %25 = arith.andi %16, %24 : tensor<8x256xi1> loc(#loc14) + %26 = tt.broadcast %22 : tensor<1x256xi32> -> tensor<8x256xi32> loc(#loc15) + %27 = tt.addptr %17, %26 : tensor<8x256x!tt.ptr>, tensor<8x256xi32> loc(#loc15) + %28 = tt.load %27, %25, %cst : tensor<8x256x!tt.ptr> loc(#loc18) + %29 = arith.addf %arg4, %28 : tensor<8x256xf32> loc(#loc19) + scf.yield %29 : tensor<8x256xf32> loc(#loc20) + } loc(#loc16) + %19 = "tt.reduce"(%18) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32 loc(callsite(#loc1 at #loc22)), %arg4: f32 loc(callsite(#loc1 at #loc22))): + %21 = arith.addf %arg3, %arg4 : f32 loc(#loc30) + tt.reduce.return %21 : f32 loc(#loc27) + }) : (tensor<8x256xf32>) -> tensor<8xf32> loc(#loc27) + %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<8xf32> -> tensor<8x1xf32> loc(#loc24) + tt.store %11, %20, %12 : tensor<8x1x!tt.ptr> loc(#loc25) + tt.return loc(#loc26) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":33:24) +#loc3 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":33:29) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":33:52) +#loc5 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":33:61) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":33:39) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":34:22) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":34:16) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":35:16) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":36:21) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":41:34) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":41:43) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":42:26) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":43:28) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":45:26) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":40:27) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":41:21) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":45:32) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":46:16) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":46:8) +#loc21 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":267:36) +#loc23 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":256:15) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":49:31) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":50:23) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/sum_dim/sum_dim.py":50:4) +#loc27 = loc(callsite(#loc21 at #loc22)) +#loc29 = loc(callsite(#loc23 at #loc21)) +#loc30 = loc(callsite(#loc29 at #loc22)) diff --git a/python/test/ops/varmean/triton-ascend/var_mean_welford_kernel.ttadapter b/python/test/ops/varmean/triton-ascend/var_mean_welford_kernel.ttadapter new file mode 100644 index 000000000..09d88c680 --- /dev/null +++ b/python/test/ops/varmean/triton-ascend/var_mean_welford_kernel.ttadapter @@ -0,0 +1,137 @@ +#map = affine_map<(d0) -> (d0)> +module { + func.func @var_mean_welford_kernel(%arg0: memref, %arg1: memref, %arg2: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %arg3: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg4: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c4_i32 = arith.constant 4 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %0 = tensor.empty() : tensor<4xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32> + %2 = tensor.empty() : tensor<4x64xf32> + %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<4x64xf32>) -> tensor<4x64xf32> + %4 = 
linalg.fill ins(%cst : f32) outs(%2 : tensor<4x64xf32>) -> tensor<4x64xf32> + %5 = arith.muli %arg11, %c4_i32 : i32 + %6 = tensor.empty() : tensor<4xi32> + %7 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%6 : tensor<4xi32>) { + ^bb0(%out: i32): + %39 = linalg.index 0 : index + %40 = arith.index_cast %39 : index to i32 + linalg.yield %40 : i32 + } -> tensor<4xi32> + %expanded = tensor.expand_shape %7 [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32> + %8 = tensor.empty() : tensor<4x1xi32> + %9 = linalg.fill ins(%5 : i32) outs(%8 : tensor<4x1xi32>) -> tensor<4x1xi32> + %10 = arith.addi %9, %expanded : tensor<4x1xi32> + %11 = arith.index_cast %5 : i32 to index + %12 = arith.index_cast %arg6 : i32 to index + %13 = arith.muli %11, %12 : index + %reinterpret_cast = memref.reinterpret_cast %arg3 to offset: [%11], sizes: [4, 1], strides: [1, 1] : memref to memref<4x1xf32, strided<[1, 1], offset: ?>> + %reinterpret_cast_1 = memref.reinterpret_cast %arg4 to offset: [%11], sizes: [4, 1], strides: [1, 1] : memref to memref<4x1xf32, strided<[1, 1], offset: ?>> + %14 = linalg.fill ins(%arg5 : i32) outs(%8 : tensor<4x1xi32>) -> tensor<4x1xi32> + %15 = arith.cmpi slt, %10, %14 : tensor<4x1xi32> + %16 = tensor.empty() : tensor<64xi32> + %17 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%16 : tensor<64xi32>) { + ^bb0(%out: i32): + %39 = linalg.index 0 : index + %40 = arith.index_cast %39 : index to i32 + linalg.yield %40 : i32 + } -> tensor<64xi32> + %expanded_2 = tensor.expand_shape %17 [[0, 1]] output_shape [1, 64] : tensor<64xi32> into tensor<1x64xi32> + %18 = tensor.empty() : tensor<1x64xi32> + %19 = linalg.fill ins(%arg6 : i32) outs(%18 : tensor<1x64xi32>) -> tensor<1x64xi32> + %20 = tensor.empty() : tensor<4x64xi1> + %collapsed = tensor.collapse_shape %15 [[0, 1]] : tensor<4x1xi1> into tensor<4xi1> + %broadcasted = linalg.broadcast ins(%collapsed : tensor<4xi1>) outs(%20 : tensor<4x64xi1>) dimensions = [1] + %21:3 = scf.for %arg14 = %c0_i32 to %arg6 step %c64_i32 iter_args(%arg15 = %4, %arg16 = %4, %arg17 = %4) -> (tensor<4x64xf32>, tensor<4x64xf32>, tensor<4x64xf32>) : i32 { + %39 = linalg.fill ins(%arg14 : i32) outs(%18 : tensor<1x64xi32>) -> tensor<1x64xi32> + %40 = arith.addi %39, %expanded_2 : tensor<1x64xi32> + %41 = arith.cmpi slt, %40, %19 : tensor<1x64xi32> + %collapsed_10 = tensor.collapse_shape %41 [[0, 1]] : tensor<1x64xi1> into tensor<64xi1> + %broadcasted_11 = linalg.broadcast ins(%collapsed_10 : tensor<64xi1>) outs(%20 : tensor<4x64xi1>) dimensions = [0] + %42 = arith.andi %broadcasted, %broadcasted_11 : tensor<4x64xi1> + %43 = arith.index_cast %arg14 : i32 to index + %44 = arith.addi %13, %43 : index + %reinterpret_cast_12 = memref.reinterpret_cast %arg2 to offset: [%44], sizes: [4, 64], strides: [%12, 1] : memref to memref<4x64xf32, strided<[?, 1], offset: ?>> + %alloc = memref.alloc() : memref<4x64xf32> + %45 = arith.addi %11, %c4 : index + %46 = arith.index_cast %arg5 : i32 to index + %47 = arith.maxsi %11, %46 : index + %48 = arith.minsi %45, %47 : index + %49 = arith.subi %48, %11 : index + %50 = arith.addi %43, %c64 : index + %51 = arith.maxsi %43, %12 : index + %52 = arith.minsi %50, %51 : index + %53 = arith.subi %52, %43 : index + %54 = arith.minsi %49, %c4 : index + %55 = arith.minsi %53, %c64 : index + %56 = arith.cmpi slt, %54, %c4 : index + %57 = arith.cmpi slt, %55, %c64 : index + %58 = arith.ori %56, %57 : i1 + scf.if %58 { + linalg.fill ins(%cst : f32) outs(%alloc : memref<4x64xf32>) + } + 
%subview_13 = memref.subview %reinterpret_cast_12[0, 0] [%54, %55] [1, 1] : memref<4x64xf32, strided<[?, 1], offset: ?>> to memref> + %subview_14 = memref.subview %alloc[0, 0] [%54, %55] [1, 1] : memref<4x64xf32> to memref> + memref.copy %subview_13, %subview_14 : memref> to memref> + %59 = bufferization.to_tensor %alloc restrict writable : memref<4x64xf32> + %60 = arith.uitofp %42 : tensor<4x64xi1> to tensor<4x64xf32> + %61 = arith.addf %arg17, %60 : tensor<4x64xf32> + %62 = arith.maxnumf %61, %3 : tensor<4x64xf32> + %63 = arith.mulf %arg16, %arg17 : tensor<4x64xf32> + %64 = arith.addf %63, %59 : tensor<4x64xf32> + %65 = arith.divf %64, %62 : tensor<4x64xf32> + %66 = arith.subf %59, %65 : tensor<4x64xf32> + %67 = arith.subf %59, %arg16 : tensor<4x64xf32> + %68 = arith.mulf %66, %67 : tensor<4x64xf32> + %69 = arith.mulf %68, %60 : tensor<4x64xf32> + %70 = arith.addf %arg15, %69 : tensor<4x64xf32> + scf.yield %70, %65, %61 : tensor<4x64xf32>, tensor<4x64xf32>, tensor<4x64xf32> + } + %22 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32> + %reduced = linalg.reduce ins(%21#2 : tensor<4x64xf32>) outs(%22 : tensor<4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %39 = arith.addf %in, %init : f32 + linalg.yield %39 : f32 + } + %23 = arith.mulf %21#1, %21#2 : tensor<4x64xf32> + %reduced_3 = linalg.reduce ins(%23 : tensor<4x64xf32>) outs(%22 : tensor<4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %39 = arith.addf %in, %init : f32 + linalg.yield %39 : f32 + } + %24 = arith.maxnumf %reduced, %1 : tensor<4xf32> + %25 = arith.divf %reduced_3, %24 : tensor<4xf32> + %expanded_4 = tensor.expand_shape %25 [[0, 1]] output_shape [4, 1] : tensor<4xf32> into tensor<4x1xf32> + %broadcasted_5 = linalg.broadcast ins(%25 : tensor<4xf32>) outs(%2 : tensor<4x64xf32>) dimensions = [1] + %26 = arith.subf %21#1, %broadcasted_5 : tensor<4x64xf32> + %27 = arith.mulf %21#2, %26 : tensor<4x64xf32> + %28 = arith.mulf %27, %26 : tensor<4x64xf32> + %29 = arith.addf %21#0, %28 : tensor<4x64xf32> + %reduced_6 = linalg.reduce ins(%29 : tensor<4x64xf32>) outs(%22 : tensor<4xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %39 = arith.addf %in, %init : f32 + linalg.yield %39 : f32 + } + %30 = arith.subi %arg6, %arg7 : i32 + %31 = arith.sitofp %30 : i32 to f32 + %32 = linalg.fill ins(%31 : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32> + %33 = arith.divf %reduced_6, %32 : tensor<4xf32> + %expanded_7 = tensor.expand_shape %33 [[0, 1]] output_shape [4, 1] : tensor<4xf32> into tensor<4x1xf32> + %34 = arith.addi %11, %c4 : index + %35 = arith.index_cast %arg5 : i32 to index + %36 = arith.maxsi %11, %35 : index + %37 = arith.minsi %34, %36 : index + %38 = arith.subi %37, %11 : index + %extracted_slice = tensor.extract_slice %expanded_4[0, 0] [%38, 1] [1, 1] : tensor<4x1xf32> to tensor + %subview = memref.subview %reinterpret_cast_1[0, 0] [%38, 1] [1, 1] : memref<4x1xf32, strided<[1, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor, memref>) -> () + %extracted_slice_8 = tensor.extract_slice %expanded_7[0, 0] [%38, 1] [1, 1] : tensor<4x1xf32> to tensor + %subview_9 = memref.subview %reinterpret_cast[0, 0] [%38, 1] [1, 1] : memref<4x1xf32, strided<[1, 1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice_8 in writable %subview_9 : (tensor, memref>) -> () + return + } +} + diff --git a/python/test/ops/varmean/triton-ascend/var_mean_welford_kernel.ttir 
b/python/test/ops/varmean/triton-ascend/var_mean_welford_kernel.ttir new file mode 100644 index 000000000..7aa768e8d --- /dev/null +++ b/python/test/ops/varmean/triton-ascend/var_mean_welford_kernel.ttir @@ -0,0 +1,145 @@ +#loc = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0) +#loc1 = loc(unknown) +#loc32 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":55:25) +#loc35 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":58:26) +#loc43 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":68:17) +#loc51 = loc(callsite(#loc1 at #loc32)) +#loc54 = loc(callsite(#loc1 at #loc35)) +#loc56 = loc(callsite(#loc1 at #loc43)) +module { + tt.func public @var_mean_welford_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0), %arg3: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0), %arg5: i32 loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":17:0)) attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+00> : tensor<4xf32> loc(#loc1) + %cst_0 = arith.constant dense<1.000000e+00> : tensor<4x64xf32> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_1 = arith.constant dense<0.000000e+00> : tensor<4x64xf32> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = arith.muli %0, %c4_i32 : i32 loc(#loc3) + %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc4) + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> loc(#loc5) + %4 = tt.splat %1 : i32 -> tensor<4x1xi32> loc(#loc6) + %5 = arith.addi %4, %3 : tensor<4x1xi32> loc(#loc6) + %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> loc(#loc7) + %7 = arith.muli %5, %6 : tensor<4x1xi32> loc(#loc7) + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc8) + %9 = tt.addptr %8, %7 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc8) + %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc9) + %11 = tt.addptr %10, %5 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc9) + %12 = tt.splat %arg2 : !tt.ptr -> tensor<4x1x!tt.ptr> loc(#loc10) + %13 = tt.addptr %12, %5 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> loc(#loc10) + %14 = tt.splat %arg3 : i32 -> tensor<4x1xi32> loc(#loc11) + %15 = arith.cmpi slt, %5, %14 : tensor<4x1xi32> loc(#loc11) + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc12) + %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc13) + %18 = tt.splat %arg4 : i32 -> tensor<1x64xi32> loc(#loc14) + %19 = tt.broadcast %15 : tensor<4x1xi1> -> tensor<4x64xi1> loc(#loc15) + %20 = tt.broadcast %9 : tensor<4x1x!tt.ptr> -> tensor<4x64x!tt.ptr> loc(#loc16) + %21:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c64_i32 iter_args(%arg7 = %cst_1, %arg8 = %cst_1, %arg9 = %cst_1) -> (tensor<4x64xf32>, tensor<4x64xf32>, tensor<4x64xf32>) : i32 { + %39 = tt.splat %arg6 : i32 -> tensor<1x64xi32> 
loc(#loc18) + %40 = arith.addi %39, %17 : tensor<1x64xi32> loc(#loc18) + %41 = arith.cmpi slt, %40, %18 : tensor<1x64xi32> loc(#loc14) + %42 = tt.broadcast %41 : tensor<1x64xi1> -> tensor<4x64xi1> loc(#loc15) + %43 = arith.andi %19, %42 : tensor<4x64xi1> loc(#loc15) + %44 = tt.broadcast %40 : tensor<1x64xi32> -> tensor<4x64xi32> loc(#loc16) + %45 = tt.addptr %20, %44 : tensor<4x64x!tt.ptr>, tensor<4x64xi32> loc(#loc16) + %46 = tt.load %45, %43, %cst_1 : tensor<4x64x!tt.ptr> loc(#loc19) + %47 = arith.uitofp %43 : tensor<4x64xi1> to tensor<4x64xf32> loc(#loc20) + %48 = arith.addf %arg9, %47 : tensor<4x64xf32> loc(#loc20) + %49 = arith.maxnumf %48, %cst_0 : tensor<4x64xf32> loc(#loc21) + %50 = arith.mulf %arg8, %arg9 : tensor<4x64xf32> loc(#loc22) + %51 = arith.addf %50, %46 : tensor<4x64xf32> loc(#loc23) + %52 = arith.divf %51, %49 : tensor<4x64xf32> loc(#loc24) + %53 = arith.subf %46, %52 : tensor<4x64xf32> loc(#loc25) + %54 = arith.subf %46, %arg8 : tensor<4x64xf32> loc(#loc26) + %55 = arith.mulf %53, %54 : tensor<4x64xf32> loc(#loc27) + %56 = arith.mulf %55, %47 : tensor<4x64xf32> loc(#loc28) + %57 = arith.addf %arg7, %56 : tensor<4x64xf32> loc(#loc29) + scf.yield %57, %52, %48 : tensor<4x64xf32>, tensor<4x64xf32>, tensor<4x64xf32> loc(#loc30) + } loc(#loc17) + %22 = "tt.reduce"(%21#2) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc32)), %arg7: f32 loc(callsite(#loc1 at #loc32))): + %39 = arith.addf %arg6, %arg7 : f32 loc(#loc57) + tt.reduce.return %39 : f32 loc(#loc50) + }) : (tensor<4x64xf32>) -> tensor<4xf32> loc(#loc50) + %23 = arith.mulf %21#1, %21#2 : tensor<4x64xf32> loc(#loc34) + %24 = "tt.reduce"(%23) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc35)), %arg7: f32 loc(callsite(#loc1 at #loc35))): + %39 = arith.addf %arg6, %arg7 : f32 loc(#loc58) + tt.reduce.return %39 : f32 loc(#loc53) + }) : (tensor<4x64xf32>) -> tensor<4xf32> loc(#loc53) + %25 = arith.maxnumf %22, %cst : tensor<4xf32> loc(#loc36) + %26 = arith.divf %24, %25 : tensor<4xf32> loc(#loc37) + %27 = tt.expand_dims %26 {axis = 1 : i32} : tensor<4xf32> -> tensor<4x1xf32> loc(#loc38) + %28 = tt.broadcast %27 : tensor<4x1xf32> -> tensor<4x64xf32> loc(#loc39) + %29 = arith.subf %21#1, %28 : tensor<4x64xf32> loc(#loc39) + %30 = arith.mulf %21#2, %29 : tensor<4x64xf32> loc(#loc40) + %31 = arith.mulf %30, %29 : tensor<4x64xf32> loc(#loc41) + %32 = arith.addf %21#0, %31 : tensor<4x64xf32> loc(#loc42) + %33 = "tt.reduce"(%32) <{axis = 1 : i32}> ({ + ^bb0(%arg6: f32 loc(callsite(#loc1 at #loc43)), %arg7: f32 loc(callsite(#loc1 at #loc43))): + %39 = arith.addf %arg6, %arg7 : f32 loc(#loc59) + tt.reduce.return %39 : f32 loc(#loc55) + }) : (tensor<4x64xf32>) -> tensor<4xf32> loc(#loc55) + %34 = arith.subi %arg4, %arg5 : i32 loc(#loc44) + %35 = arith.sitofp %34 : i32 to f32 loc(#loc45) + %36 = tt.splat %35 : f32 -> tensor<4xf32> loc(#loc45) + %37 = arith.divf %33, %36 : tensor<4xf32> loc(#loc45) + %38 = tt.expand_dims %37 {axis = 1 : i32} : tensor<4xf32> -> tensor<4x1xf32> loc(#loc46) + tt.store %13, %27, %15 : tensor<4x1x!tt.ptr> loc(#loc47) + tt.store %11, %38, %15 : tensor<4x1x!tt.ptr> loc(#loc48) + tt.return loc(#loc49) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":28:24) +#loc3 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":28:29) +#loc4 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":28:52) +#loc5 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":28:61) +#loc6 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":28:39) +#loc7 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":29:18) +#loc8 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":29:12) +#loc9 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":30:16) +#loc10 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":31:18) +#loc11 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":32:21) +#loc12 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":39:34) +#loc13 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":39:43) +#loc14 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":40:26) +#loc15 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":41:28) +#loc16 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":42:24) +#loc17 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":38:27) +#loc18 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":39:21) +#loc19 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":42:30) +#loc20 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":44:25) +#loc21 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":45:32) +#loc22 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":46:28) +#loc23 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":46:37) +#loc24 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":46:42) +#loc25 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":47:21) +#loc26 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":47:38) +#loc27 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":47:34) +#loc28 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":47:47) +#loc29 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":47:16) +#loc30 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":49:8) +#loc31 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":267:36) +#loc33 = loc("/usr/local/python3.11.13/lib/python3.11/site-packages/triton/language/standard.py":256:15) +#loc34 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":58:34) +#loc36 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":59:50) +#loc37 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":59:26) +#loc38 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":63:25) +#loc39 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":67:49) +#loc40 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":67:41) +#loc41 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":67:67) +#loc42 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":67:31) +#loc44 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":70:21) +#loc45 = 
loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":70:17) +#loc46 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":72:14) +#loc47 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":75:25) +#loc48 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":76:23) +#loc49 = loc("/home/zhengyang/git/flagtree/python/test/ops/varmean/var_mean_ascend.py":76:4) +#loc50 = loc(callsite(#loc31 at #loc32)) +#loc52 = loc(callsite(#loc33 at #loc31)) +#loc53 = loc(callsite(#loc31 at #loc35)) +#loc55 = loc(callsite(#loc31 at #loc43)) +#loc57 = loc(callsite(#loc52 at #loc32)) +#loc58 = loc(callsite(#loc52 at #loc35)) +#loc59 = loc(callsite(#loc52 at #loc43)) diff --git a/python/test/ops/varmean/var_mean.py b/python/test/ops/varmean/var_mean.py new file mode 100644 index 000000000..c8b585330 --- /dev/null +++ b/python/test/ops/varmean/var_mean.py @@ -0,0 +1,127 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit +def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y): + count = count_x + count_y + _count = tl.maximum(count, 1) + mc_x = mean_x * count_x + mc_y = mean_y * count_y + mean = (mc_x + mc_y) / _count + M = M_x + mc_x * mean_x + M_y + mc_y * mean_y - count * mean * mean + return mean, count, M + + +@triton.jit(do_not_specialize=["correction"]) +def var_mean_welford_kernel( + X, + Var, + Mean, + M, + N, + correction, + BLOCK_M: tl.constexpr = 4, + BLOCK_N: tl.constexpr = 64, +): + # Map the program id to the row of X it should compute. 
+ pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + X = X + pid * N + Var = Var + pid + Mean = Mean + pid + row_mask = pid < M + + _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N)[None, :] + col_mask = cols < N + mask = row_mask and col_mask + + x = tl.load(X + cols, mask, other=0.0).to(tl.float32) + + count = _count + mask + cnt = tl.maximum(count, 1) + cur_mean = (_mean * _count + x) / cnt + _acc += (x - cur_mean) * (x - _mean) * mask + _mean = cur_mean + _count = count + + mean, _, acc = tl.reduce((_mean, _count, _acc), axis=1, combine_fn=welford_func) + var = acc / (N - correction) + mean = mean[:, None] + var = var[:, None] + # Write mean / var + tl.store(Mean, mean, row_mask) + tl.store(Var, var, row_mask) + + +def var_mean(x, dim=None, *, correction=None, keepdim=False): + if correction is None: + correction = 1.0 + + if dim is None or len(dim) == x.ndim: + assert False + else: + shape = list(x.shape) + dim = [d % x.ndim for d in dim] + N = 1 + for i in dim: + N *= shape[i] + shape[i] = 1 + M = x.numel() // N + var = torch.empty(shape, dtype=x.dtype, device=x.device) + mean = torch.empty(shape, dtype=x.dtype, device=x.device) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) + with torch_device_fn.device(x.device): + var_mean_welford_kernel[grid](x, var, mean, M, N, correction) + + if not keepdim: + var = var.squeeze(dim=dim) + mean = mean.squeeze(dim=dim) + return var, mean + + +if __name__ == "__main__": + # param + shape = (2, 32) + dim = [1] + correction = 1 + keepdim = True + + # inp + inp = torch.randn(shape, dtype=torch.float32, device=device) + ref_inp = inp.cpu() + + # op + ref_var, ref_mean = torch.var_mean( + ref_inp, dim, correction=correction, keepdim=keepdim + ) + res_var, res_mean = var_mean(inp, dim, correction=correction, keepdim=keepdim) + + # check + res_var = res_var.cpu() + print( + f"The maximum difference var between torch and triton is " + f"{torch.max(torch.abs(ref_var - res_var))}" + ) + assert torch.allclose(res_var, ref_var), (res_var, ref_var) + res_mean = res_mean.cpu() + print( + f"The maximum difference mean between torch and triton is " + f"{torch.max(torch.abs(ref_mean - res_mean))}" + ) + assert torch.allclose(res_mean, ref_mean), (res_mean, ref_mean) diff --git a/python/test/ops/varmean/var_mean_ascend.py b/python/test/ops/varmean/var_mean_ascend.py new file mode 100644 index 000000000..ce4ed814a --- /dev/null +++ b/python/test/ops/varmean/var_mean_ascend.py @@ -0,0 +1,135 @@ +import torch +import triton +import triton.language as tl + +# active driver +driver = triton.runtime.driver.active +# torch.cuda, torch.aipu, torch.npu +torch_device_fn = triton.runtime.driver.active.get_device_interface() +# device +if hasattr(driver, "get_active_torch_device"): + device = triton.runtime.driver.active.get_active_torch_device() +else: + device = triton.runtime.driver.active.get_current_device() + + +@triton.jit(do_not_specialize=["correction"]) +def var_mean_welford_kernel( + X, + Var, + Mean, + M, + N, + correction, + BLOCK_M: tl.constexpr = 4, + BLOCK_N: tl.constexpr = 64, +): + # Map the program id to the row of X it should compute. 
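+ # Same row mapping as var_mean.py: pid is a (BLOCK_M, 1) column of row indices and row_mask guards out-of-range rows. + # This Ascend variant accumulates the same per-column Welford partials but merges them further below with tl.sum instead of tl.reduce/welford_func.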
+ pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] + X = X + pid * N + Var = Var + pid + Mean = Mean + pid + row_mask = pid < M + + _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + _acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + _count = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for off in range(0, N, BLOCK_N): + cols = off + tl.arange(0, BLOCK_N)[None, :] + col_mask = cols < N + mask = row_mask and col_mask + x = tl.load(X + cols, mask, other=0.0).to(tl.float32) + + count = _count + mask + cnt = tl.maximum(count, 1) + cur_mean = (_mean * _count + x) / cnt + _acc += (x - cur_mean) * (x - _mean) * mask + _mean = cur_mean + _count = count + + # Manually implement the behavior of tl.reduce, reducing along axis=1 + # Use tl.sum for the reduction, which is equivalent to what the Welford algorithm does in this case + + # Compute the total count of each row + total_count = tl.sum(_count, axis=1) # shape: (BLOCK_M,) + + # Compute the weighted mean + weighted_sum = tl.sum(_mean * _count, axis=1) # shape: (BLOCK_M,) + mean = weighted_sum / tl.maximum(total_count, 1) # shape: (BLOCK_M,) + + # Compute the accumulated variance term + # For each element, compute its contribution to the overall variance + mean_expanded = mean[:, None] # shape: (BLOCK_M, 1) + + # Compute the contribution of each local statistic to the overall variance + # This is the parallelized version of the Welford algorithm + local_var_contrib = _acc + _count * (_mean - mean_expanded) * (_mean - mean_expanded) + acc = tl.sum(local_var_contrib, axis=1) # shape: (BLOCK_M,) + + var = acc / (N - correction) + mean = mean[:, None] + var = var[:, None] + + # Write mean / var + tl.store(Mean, mean, row_mask) + tl.store(Var, var, row_mask) + + +def var_mean(x, dim=None, *, correction=None, keepdim=False): + if correction is None: + correction = 1.0 + + if dim is None or len(dim) == x.ndim: + assert False + else: + shape = list(x.shape) + dim = [d % x.ndim for d in dim] + N = 1 + for i in dim: + N *= shape[i] + shape[i] = 1 + M = x.numel() // N + var = torch.empty(shape, dtype=x.dtype, device=x.device) + mean = torch.empty(shape, dtype=x.dtype, device=x.device) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) + with torch_device_fn.device(x.device): + var_mean_welford_kernel[grid](x, var, mean, M, N, correction) + + if not keepdim: + var = var.squeeze(dim=dim) + mean = mean.squeeze(dim=dim) + return var, mean + + +if __name__ == "__main__": + # param + shape = (2, 32) + dim = [1] + correction = 1 + keepdim = True + + # inp + inp = torch.randn(shape, dtype=torch.float32, device=device) + ref_inp = inp.cpu() + + # op + ref_var, ref_mean = torch.var_mean( + ref_inp, dim, correction=correction, keepdim=keepdim + ) + res_var, res_mean = var_mean(inp, dim, correction=correction, keepdim=keepdim) + + # check + res_var = res_var.cpu() + print( + f"The maximum difference var between torch and triton is " + f"{torch.max(torch.abs(ref_var - res_var))}" + ) + assert torch.allclose(res_var, ref_var), (res_var, ref_var) + res_mean = res_mean.cpu() + print( + f"The maximum difference mean between torch and triton is " + f"{torch.max(torch.abs(ref_mean - res_mean))}" + ) + assert torch.allclose(res_mean, ref_mean), (res_mean, ref_mean) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index d8ca58d8d..155c2824e 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -155,12 +155,10 @@ def visit_Return(self, node: ast.Return) -> bool: def visit_Assign(self, node: ast.Assign) -> bool: # There couldn't be an early return - # x = ... return False def visit_AugAssign(self, node: ast.AugAssign) -> bool: # There couldn't be an early return - # x += ...
return False def visit_Module(self, node: ast.Module) -> bool: @@ -170,13 +168,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: return self._visit_stmts(node.body) def visit_If(self, node: ast.If) -> bool: - # TODO: optimize the following case in which we actually don't have - # a return when static_cond is false: - # if dynamic_cond - # if static_cond - # func_with_return - # else - # func_without_return ret = self._visit_stmts(node.body) if node.orelse: ret = ret or self._visit_stmts(node.orelse) @@ -201,9 +192,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) self.builder.options = options - # dict of functions provided by the backend. Below are the list of possible functions: - # Convert custom types not natively supported on HW. - # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) self.builder.codegen_fns = codegen_fns self.builder.module_map = {} if module_map is None else module_map self.module = self.builder.create_module() if module is None else module @@ -486,6 +474,7 @@ def visit_AnnAssign(self, node): return self.visit_Assign(node) def visit_Assign(self, node): + # flagtree: First, do normal assignment processing _names = [] if isinstance(node, ast.AnnAssign): _names += [self.visit(node.target)] @@ -509,6 +498,30 @@ def visit_Assign(self, node): not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) + + # flagtree: After normal processing, check if we need to add hint annotation + if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = self.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a tl.load call with dot_pad_only_k hint + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Attribute) and + isinstance(node.value.func.value, ast.Name) and + node.value.func.value.id == 'tl' and + node.value.func.attr == 'load'): + + # Add hint annotation to the loaded tensor(s) + for name, value in zip(names, values): + if _is_triton_value(value): + # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") + # Create hint annotation + hint_val = self.builder.get_unit_attr() + self.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) def visit_AugAssign(self, node): name = node.target.id @@ -815,8 +828,6 @@ def visit_While(self, node): liveins, insert_block = sr ip, last_loc = self._get_insertion_point_and_loc() - # loop body (the after region) - # loop_block = self.builder.create_block() dummy = self.builder.create_block() self.builder.set_insertion_point_to_start(dummy) self.scf_stack.append(node) @@ -910,7 +921,8 @@ def visit_For(self, node): return num_stages = None loop_unroll_factor = None - if IteratorClass is language.range: + bind_sub_block = None + if IteratorClass in [language.range, language.parallel]: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -920,6 +932,8 @@ def visit_For(self, node): step = iterator.step num_stages = iterator.num_stages loop_unroll_factor = 
iterator.loop_unroll_factor + if (IteratorClass is language.parallel): + bind_sub_block = iterator.bind_sub_block elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -929,6 +943,20 @@ def visit_For(self, node): step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) else: raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + + # flagtree: After normal processing, check if we need to override bind_sub_block + if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = self.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a range/for loop with bind_sub_block hint + if flagtree_hints and 'bind_sub_block' in flagtree_hints: + bind_sub_block = True + # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") + # handle negative constant step (not supported by scf.for in MLIR) negative_step = False if _is_constexpr(step) and step.value < 0: @@ -992,6 +1020,8 @@ def visit_For(self, node): for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) if loop_unroll_factor is not None: for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if (bind_sub_block is not None) and bind_sub_block: + for_op.set_attr("bind_sub_block", self.builder.get_bool_attr(bind_sub_block)) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) @@ -1075,7 +1105,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): generator.visit(fn.parse()) except Exception as e: # Wrap the error in the callee with the location of the call. - raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e callee_ret_type = generator.ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1119,7 +1149,7 @@ def visit_Call(self, node): # itself). But when calling a function, we raise as `from e` to # preserve the traceback of the original error, which may e.g. # be in core.py. - raise CompilationError(self.jit_fn.src, node, None) from e + raise CompilationError(self.jit_fn.src, node, repr(e)) from e if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) diff --git a/python/triton/compiler/code_generator.py.std b/python/triton/compiler/code_generator.py.std new file mode 100644 index 000000000..d8ca58d8d --- /dev/null +++ b/python/triton/compiler/code_generator.py.std @@ -0,0 +1,1303 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from .. 
import language +from .._C.libtriton import ir +from ..language import constexpr, tensor, str_to_ty +from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value +from ..runtime.jit import _normalize_ty, get_jit_fn_file_line +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + raise TypeError(f'Unsupported type {ty}') + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_value(o: Any) -> bool: + return isinstance(o, _value) + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + return any(self.visit(s) for s in body) + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def 
generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) is ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. 
+ # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if any([ + val is absent, name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + isinstance(val, language.dtype), # + self._is_constexpr_global(name), # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + self.visiting_arg_default_value, # + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" + ]): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). 
Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.semantic.to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. 
+ post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults[::-1]): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i in range(len(arg_names)): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + + # Mark this argument as a pass-by-value TMA descriptor (nvidia) + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + + # finalize function + assert not self.builder.get_insertion_block().has_terminator() + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + self.builder.ret([ + self.builder.create_poison(ty.to_ir(self.builder)) + for ty in self.prototype.ret_types + if self.ret_type is not None + ]) + self.fn.finalize() + + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' 
+ f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + if isinstance(node, ast.AnnAssign): + _names += [self.visit(node.target)] + else: + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = language.semantic.to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose 
value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in sorted(then_defs.keys() & else_defs.keys()): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'Mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() + # then terminator + self.builder.set_insertion_point_to_end(then_block) + assert not then_block.has_terminator(), f"{then_block}" + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + assert not else_block.has_terminator(), f"{else_block}" + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + for ty in ir_ret_types: + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in 
names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if contains_return: + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.semantic.to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', 
'.join(_.__name__ for _ in _condition_types),
+                        type(cond).__name__))
+            if cond:
+                return self.visit(node.body)
+            else:
+                return self.visit(node.orelse)
+
+    def visit_Pass(self, node):
+        pass
+
+    def visit_Compare(self, node):
+        if not (len(node.comparators) == 1 and len(node.ops) == 1):
+            raise self._unsupported(node, "simultaneous multiple comparison is not supported")
+        lhs = self.visit(node.left)
+        rhs = self.visit(node.comparators[0])
+        lhs_value = _unwrap_if_constexpr(lhs)
+        rhs_value = _unwrap_if_constexpr(rhs)
+        if type(node.ops[0]) is ast.Is:
+            return constexpr(lhs_value is rhs_value)
+        if type(node.ops[0]) is ast.IsNot:
+            return constexpr(lhs_value is not rhs_value)
+        method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
+        if method_name is None:
+            raise self._unsupported(
+                node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
+        return self._apply_binary_method(method_name, lhs, rhs)
+
+    _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
+        ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
+    }
+
+    def visit_UnaryOp(self, node):
+        operand = self.visit(node.operand)
+        fn = self._method_name_for_unary_op.get(type(node.op))
+        if fn is None:
+            raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
+        if _is_triton_tensor(operand):
+            return getattr(operand, fn)(_builder=self.builder)
+        try:
+            return getattr(operand, fn)()
+        except AttributeError:
+            raise self._unsupported(
+                node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
+
+    _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
+        ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
+    }
+
+    def _verify_loop_carried_variable(self, name, loop_val, live_val):
+        assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
+        assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
+        assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type'
+        assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
+            f'Loop-carried variable {name} has initial type {live_val.type} '\
+            f'but is re-assigned to {loop_val.type} in loop! '\
+            f'Please make sure that the type stays consistent.'
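+
+    # The loop lowerings below (visit_While / visit_For) rely on the check above:
+    # a name is treated as loop-carried when it is live before the loop and
+    # re-assigned inside the loop body, e.g. (illustrative kernel fragment only):
+    #
+    #     i = 0
+    #     while i < n:
+    #         i += 1
+    #
+    # Here `i` is detected as loop-carried, so it becomes a block argument of the
+    # scf.while regions (or an iter_arg of scf.for) and a result of the loop op.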
+ + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + loop_val = loop_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + # these are loop-carried values + names.append(name) + ret_types.append(loop_val.type) + init_args.append(live_val) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in 
node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + loop_unroll_factor = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.semantic.to_tensor(lb, self.builder) + ub = language.semantic.to_tensor(ub, self.builder) + step = language.semantic.to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_poison(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. 
(They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + loop_val = self.local_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + names.append(name) + init_args.append(live_val) + yields.append(loop_val) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if loop_unroll_factor is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = {} + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = 
language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = get_jit_fn_file_line(fn) + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = {"_builder": self.builder} + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. 
+ raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs) and node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. 
+ raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisibility_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True + + new_attrs = attrs.filter_out_constants() + fn_attrs = new_attrs.get_fn_attrs() + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = get_jit_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 8ca1f8b32..6a8359d6f 100644 --- a/python/triton/compiler/compiler.py +++ 
b/python/triton/compiler/compiler.py @@ -4,6 +4,7 @@ from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor +from ..backends.ascend.compiler import AscendAttrsDescriptor from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager @@ -11,6 +12,7 @@ from ..tools.disasm import get_sass # TODO: this shouldn't be here from .code_generator import ast_to_ttir +from .errors import MLIRCompilationError from pathlib import Path import re import functools @@ -86,7 +88,7 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: if not isinstance(k, str): raise TypeError("Constants keys must be string") if self.attrs is None: - self.attrs = AttrsDescriptor() + self.attrs = AscendAttrsDescriptor() def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] @@ -250,6 +252,12 @@ def compile(src, target=None, options=None): # cache hit! metadata = json.loads(Path(metadata_path).read_text()) return CompiledKernel(src, metadata_group, hash) + compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') + if (compile_speed_opt): + ttir_path = f"{file_name}.ttir" + if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): + # Already compile once but failed. So directly return + raise Exception("already failed once") # initialize metadata metadata = { "hash": hash, @@ -276,7 +284,17 @@ def compile(src, target=None, options=None): raise use_ir_loc = os.environ.get("USE_IR_LOC", None) for ext, compile_ir in list(stages.items())[first_stage:]: - next_module = compile_ir(module, metadata) + try: + next_module = compile_ir(module, metadata) + except Exception as e: + if (ext == "ttadapter"): + stage_name = "ConvertTritonIRToLinalgIR" + elif (ext == "npubin"): + stage_name = "ConvertLinalgRToBinary" + else: + stage_name = "MLIRCompile" + error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) + raise MLIRCompilationError(stage_name, error_detail) ir_filename = f"{file_name}.{ext}" if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): print(f"\nOverriding kernel with file {full_name}") @@ -330,7 +348,6 @@ def add(self, func, args): class AsmDict(dict): def __missing__(self, key): - if key == "sass": value = get_sass(self["cubin"]) else: @@ -390,10 +407,8 @@ def _init_handles(self): self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( self.name, self.kernel, self.metadata.shared, device) - def __getattribute__(self, name): - if name == 'run': - self._init_handles() - return super().__getattribute__(name) + # This mechanism introduces heavy runtime overhead. 
+ # Commenting __getattribute__ requires explicitly calling _init_handles() def launch_metadata(self, grid, stream, *args): if CompiledKernel.launch_enter_hook is None: @@ -416,6 +431,8 @@ def __getitem__(self, grid): self._init_handles() def runner(*args, stream=None): + if stream is None: + stream = self.metadata.stream if stream is None: device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) diff --git a/python/triton/compiler/compiler.py.std b/python/triton/compiler/compiler.py.std new file mode 100644 index 000000000..8ca1f8b32 --- /dev/null +++ b/python/triton/compiler/compiler.py.std @@ -0,0 +1,426 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget, AttrsDescriptor +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +from ..tools.disasm import get_sass +# TODO: this shouldn't be here +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". 
+ num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") + if self.constants is None: + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys must be string") + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + 
'-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1": + return + + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! 
+ metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() + try: + module = src.make_ir(options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + use_ir_loc = os.environ.get("USE_IR_LOC", None) + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{file_name}.{ext}" + if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + context.disable_multithreading() + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). 
There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class AsmDict(dict): + + def __missing__(self, key): + + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = AsmDict({ + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + }) + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, 
grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/python/triton/compiler/errors.py b/python/triton/compiler/errors.py index 39e6c4dfb..5242258ad 100644 --- a/python/triton/compiler/errors.py +++ b/python/triton/compiler/errors.py @@ -49,3 +49,20 @@ class CompileTimeAssertionFailure(CompilationError): class UnsupportedLanguageConstruct(CompilationError): pass + + +class MLIRCompilationError(TritonError): + def __init__(self, stage_name: Optional[str], message: Optional[str] = None): + self.stage_name = stage_name + self.message = f"\n" \ + f"{self.format_line_delim('[ERROR][Triton][BEG]')}" \ + f"[{self.stage_name}] encounters error:\n" \ + f"{self.filter_message(message)}" \ + f"{self.format_line_delim('[ERROR][Triton][END]')}" + def __str__(self): + return self.message + def filter_message(self, message): + # Content starting from "Stack dump without symbol names" means nothing to the users + return message.split("Stack dump without symbol names")[0] + def format_line_delim(self, keyword): + return f"///------------------{keyword}------------------\n" \ No newline at end of file diff --git a/python/triton/compiler/errors.py.std b/python/triton/compiler/errors.py.std new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/python/triton/compiler/errors.py.std @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/python/triton/language/_utils.py b/python/triton/language/_utils.py index b9aa69071..d0ca8c734 100644 --- a/python/triton/language/_utils.py +++ b/python/triton/language/_utils.py @@ -1,21 +1,39 @@ -from typing import List - -TRITON_MAX_TENSOR_NUMEL = 1048576 +from __future__ import annotations +from typing import List, TYPE_CHECKING, Any, Union, Dict -def 
is_power_of_two(x): - return (x & (x - 1)) == 0 +if TYPE_CHECKING: + from .language import core + IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] + ObjPath = tuple[int, ...] +TRITON_MAX_TENSOR_NUMEL = 1048576 def validate_block_shape(shape: List[int]): numel = 1 for i, d in enumerate(shape): if not isinstance(d, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") - if not is_power_of_two(d): - raise ValueError(f"Shape element {i} must be a power of 2") numel *= d if numel > TRITON_MAX_TENSOR_NUMEL: raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") return numel + + +BITWIDTH_DICT: Dict[str, int] = { + **{f"u{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"i{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"fp{n}": n + for n in (16, 32, 64)}, + **{f"fp8{suffix}": 8 + for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, + "bf16": 16, + "void": 0, +} + + +def get_primitive_bitwidth(dtype: str) -> int: + return BITWIDTH_DICT[dtype] diff --git a/python/triton/language/_utils.py.std b/python/triton/language/_utils.py.std new file mode 100644 index 000000000..b9aa69071 --- /dev/null +++ b/python/triton/language/_utils.py.std @@ -0,0 +1,21 @@ +from typing import List + +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel diff --git a/python/triton/language/core_ext.py b/python/triton/language/core_ext.py new file mode 100644 index 000000000..f44af20c2 --- /dev/null +++ b/python/triton/language/core_ext.py @@ -0,0 +1,478 @@ +import os +from typing import List, Sequence, Optional, Union + +from triton._C.libtriton import ir +from triton.language import semantic as real_semantic +from triton.language.core import ( + _constexpr_to_value, + _tensor_member_fn, + _unwrap_iterable, + builtin, + constexpr, + dtype as real_dtype, + float32, + tensor, + check_bit_width, + _unwrap_if_constexpr, + range, + add, + sub, + mul, +) +from typing import Optional +from . import semantic_ext as semantic +from .tensor_descriptor import tensor_descriptor, tensor_descriptor_base + + +@_tensor_member_fn +@builtin +def cast(input, dtype: real_dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, overflow_mode: Optional[str] = None, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: tl.dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. 
+    :type bitcast: bool, optional
+    :param overflow_mode: When overflow_mode is not set or is "trunc",
+        truncation (cut-off) will be used to handle overflow. When
+        overflow_mode is "saturate", the maximum value of the data type
+        will be used to handle overflow.
+    :type overflow_mode: string, optional
+    """
+    overflow_modes = ["trunc", "saturate"]
+    input = semantic.to_tensor(input, _builder)
+    if isinstance(bitcast, constexpr):
+        bitcast = bitcast.value
+    if bitcast:
+        return semantic.bitcast(input, dtype, _builder)
+    ret = semantic.cast(input, dtype, _builder, fp_downcast_rounding)
+    if overflow_mode is not None:
+        if overflow_mode in overflow_modes:
+            semantic.compile_hint(ret, "overflow_mode", overflow_mode, _builder)
+        else:
+            raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.")
+    return ret
+
+
+@_tensor_member_fn
+@builtin
+def trans(input: tensor, *dims, _builder=None):
+    """
+    Permutes the dimensions of a tensor.
+
+    If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation,
+    effectively transposing a 2D tensor.
+
+    :param input: The input tensor.
+    :param dims: The desired ordering of dimensions. For example,
+        :code:`(2, 1, 0)` reverses the order of the dims in a 3D tensor.
+
+    :code:`dims` can be passed as a tuple or as individual parameters: ::
+
+        # These are equivalent
+        trans(x, (2, 1, 0))
+        trans(x, 2, 1, 0)
+
+    :py:func:`permute` is equivalent to this function, except it doesn't
+    have the special case when no permutation is specified.
+    """
+    if not dims:
+        dims = (1, 0)
+    dims = _unwrap_iterable(dims)
+    return real_semantic.permute(input, dims, _builder)
+
+
+@builtin
+def dot(
+    input,
+    other,
+    acc=None,
+    input_precision=None,
+    allow_tf32=None,
+    max_num_imprecise_acc=None,
+    out_dtype=float32,
+    _builder=None,
+):
+    assert (
+        input_precision is None or allow_tf32 is None
+    ), "Only one of input_precision and allow_tf32 can be specified"
+    assert (
+        not allow_tf32
+    ), "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead."
+    if input_precision is None:
+        supports_tf32 = (
+            _builder and "tf32" in _builder.options.allowed_dot_input_precisions
+        )
+        default_precision = (
+            "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee"
+        )
+        input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision)
+    else:
+        assert input_precision not in [
+            "tf32",
+            "tf32x3",
+        ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead."
+    input_precision = _constexpr_to_value(input_precision)
+    out_dtype = _constexpr_to_value(out_dtype)
+    max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
+    return semantic.dot(
+        input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder
+    )
+
+
+@_tensor_member_fn
+@builtin
+def gather(src, index, axis, _builder=None):
+    """Gather from a tensor along a given dimension.
+    :param src: the source tensor
+    :type src: Tensor
+    :param index: the index tensor
+    :type index: Tensor
+    :param axis: the dimension to gather along
+    :type axis: int
+    """
+    axis = _constexpr_to_value(axis)
+    return semantic.gather(src, index, axis, _builder)
+
+
+@_tensor_member_fn
+@builtin
+def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor:
+    """
+    Insert a tensor into another tensor as specified by the operation’s offsets, sizes and strides arguments.
+
+    :param ful: The tensor that receives the inserted slice.
+    :type ful: Tensor
+    :param sub: The tensor to be inserted.
+ :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + new_offsets = [ + real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + out = semantic.insert_slice(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@_tensor_member_fn +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. + :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + new_offsets = [ + real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + sub = semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) + return sub + +@_tensor_member_fn +@builtin +def get_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. + + :param src: The tensor to be accessed. + :type src: Tensor + :param indice: + :type indice: tuple of ints + """ + assert len(src.shape) > 0 + new_indice = [ + real_semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i + for i in indice + ] + return semantic.get_element(src, new_indice, _builder) + +@builtin +def __add__(self, other, _builder=None): + return add(self, other, sanitize_overflow=False, _builder=_builder) + +@builtin +def __radd__(self, other, _builder=None): + return add(other, self, sanitize_overflow=False, _builder=_builder) + +@builtin +def __sub__(self, other, _builder=None): + return sub(self, other, sanitize_overflow=False, _builder=_builder) + +@builtin +def __rsub__(self, other, _builder=None): + return sub(other, self, sanitize_overflow=False, _builder=_builder) + +@builtin +def __mul__(self, other, _builder=None): + return mul(self, other, sanitize_overflow=False, _builder=_builder) + +@builtin +def __rmul__(self, other, _builder=None): + return mul(other, self, sanitize_overflow=False, _builder=_builder) + + +@builtin +def __lshift__(self, other, _builder=None): + if self.type.scalar.is_floating(): + raise TypeError(f"unexpected type {self.type.scalar}") + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + return semantic.shl(self, other, _builder) + + +@builtin +def __rshift__(self, other, _builder=None): + if self.type.scalar.is_floating(): + raise TypeError(f"unexpected type {self.type.scalar}") + other = _unwrap_if_constexpr(other) + check_bit_width(self, other) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + +class parallel(range): + """ + Iterator that counts upward forever, with parallel execution semantics. + + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. 
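+    A minimal, illustrative sketch (loop bounds and body are placeholders, not
+    part of this patch)::
+
+        for i in parallel(0, n_iters, bind_sub_block=True):
+            ...  # body is dispatched across the available vector cores
+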
+ :param bind_sub_block: Tells the compiler if multiple vector cores participate in the loop. + This is used in the mixed cube-vector kernel on 910B. The number of vector cores is determined by the number of + iteration in this loop. Currently on 910B, max 2 vector cores could be used. + """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, bind_sub_block: bool = False): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + self.bind_sub_block = bind_sub_block + + +@builtin +def compile_hint(ptr, hint_name, hint_val=None, _builder=None): + def _unwrap(val): + return _unwrap_if_constexpr(val) if val else val + + hint_name = _constexpr_to_value(hint_name) + assert isinstance(hint_name, str), f"hint name: {hint_name} is not string" + if isinstance(hint_val, list): + hint_val = [_unwrap(val) for val in hint_val] + else: + hint_val = _unwrap(hint_val) + hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val + semantic.compile_hint(ptr, hint_name, hint_val, _builder) + + +@builtin +def sort(ptr, dim=-1, descending=False, _builder=None): + """ + Triton sort 前端接口 + + 参数: + ptr: tl.tensor,输入张量 + dim: int 或 tl.constexpr[int],排序维度 + descending: bool 或 tl.constexpr[bool],是否降序 + _builder: ir.builder,底层 IR 构建器 + 返回: + values: tl.tensor,排序后的值(类型与输入一致) + """ + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. Error: {str(e)}") from e + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + ret = semantic.sort(ptr, dim, descending, _builder) + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty.is_int8() or base_ty.is_int16(): + semantic.compile_hint(ret, "overflow_mode", constexpr("saturate"), _builder) + return ret + + +@builtin +def multibuffer(src: tensor, size, _builder=None): + """ + Set multi_buffer for an existing tensor + :src: tensor set to bufferize multiple time + :size: number of copies + """ + buffer_size = _constexpr_to_value(size) + assert isinstance(buffer_size, int) and buffer_size == 2, f"only support bufferize equals 2" + semantic.compile_hint(src, "multi_buffer", buffer_size, _builder) + + +@builtin +def sync_block_all(mode, event_id, _builder=None): + mode = _constexpr_to_value(mode) + event_id = _constexpr_to_value(event_id) + assert isinstance(mode, str), f"mode: {mode} is not string" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + assert mode == "all_cube" or mode == "all_vector" or mode == "all", f"ERROR: mode = {mode}, only supports all_cube/all_vector/all" + semantic.custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id) + + +@builtin +def sync_block_set(sender, receiver, event_id, _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + event_id = _constexpr_to_value(event_id) + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> 
vector or vector -> cube') + semantic.custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id) + + +@builtin +def sync_block_wait(sender, receiver, event_id, _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + event_id = _constexpr_to_value(event_id) + assert isinstance(sender, str) and (sender == "cube" or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') + semantic.custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id) + + +@builtin +def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], + _builder=None) -> tensor: + """Load a block of data from a tensor descriptor.""" + return desc.load(offsets, _builder=_builder) + + +@builtin +def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[Union[constexpr, tensor]], value: tensor, + _builder=None) -> tensor: + """Store a block of data to a tensor descriptor.""" + return desc.store(offsets, value, _builder=_builder) + + +@builtin +def make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + _builder=None, +) -> tensor_descriptor: + """Make a tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers representing the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2-5 dimensional tensors are supported. + + Example + ******* + .. 
code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M // M_BLOCK, N // N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) + + +def dtype_to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + raise ValueError(f'unexpected type fp8.') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + diff --git a/python/triton/language/math_ext.py b/python/triton/language/math_ext.py new file mode 100644 index 000000000..667263df4 --- /dev/null +++ b/python/triton/language/math_ext.py @@ -0,0 +1,174 @@ +from functools import wraps +from typing import List +from triton.language import core +from triton.language.math import _add_math_1arg_docstr, _add_math_2arg_docstr, _add_math_3arg_docstr +from triton.language import semantic + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. 
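+    For example, with the float-only decorations below, calling exp on an int32
+    tensor raises ValueError("Expected dtype ['bf16', 'fp16', 'fp32'] but got int32");
+    the caller should cast explicitly first, e.g. x.to(tl.float32).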
+ """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + arg_type = arg.type.scalar.name + if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8: + # In Triton, int1 maps to the boolean type + arg_type = 'int1' + if arg_type not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg_type}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +@core.builtin +@_check_dtype(dtypes=["int32", "uint32"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") +def div_rn(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = 
semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def tanh(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_tanh(x.handle), x.type) + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + z = semantic.to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) + diff --git a/python/triton/language/semantic_ext.py b/python/triton/language/semantic_ext.py new file mode 100644 index 000000000..8ef3e2ec9 --- /dev/null +++ b/python/triton/language/semantic_ext.py @@ -0,0 +1,821 @@ +from typing import List, Optional, Union, Tuple +import numbers +import triton.language as tl +from triton._C.libtriton import ir +from triton.language.semantic import wrap_tensor, _str_to_rounding_mode, not_equal, _str_to_dot_input_precision, \ + binary_op_type_checking_impl, integer_promote_impl, broadcast_impl_shape, _str_to_sem, _str_to_scope, bitcast, \ + bitwise_op_type_checking_impl, shl, ashr, lshr, fdiv, sub, mul, to_tensor +import triton.language.math as math +import triton.language.core as core +from triton.language._utils import TRITON_MAX_TENSOR_NUMEL + +from .tensor_descriptor import ( + _unwrap_if_constexpr, + _unwrap_shape, + block_type, + tensor_descriptor +) + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if range > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, 
tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty == dst_sca_ty: + return input + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now." + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard 
floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + # All combinations of supported fp8 x fp8 are permitted + pass + else: + assert lhs.dtype in (tl.int1, tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int1, tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" + + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." 
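+    # min_dot_size yields the target's lower bounds as an (M, N, K) tuple; the
+    # assertion below checks the lhs/rhs block shapes against those minima.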
+ min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + K = lhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): + if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): + raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + + if max_num_imprecise_acc is not None: + tl.static_print("max_num_imprecise_acc is not supported on Ascend yet. Thus it is ignored.") + max_num_imprecise_acc = 0 + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + +# Use Union instead of |. Becase python 3.9 does not support |. 
+# It will reports error: TypeError: unsupported operand type(s) for |: 'type' and 'ABCMeta' +def floordiv(input: Union[tl.tensor, numbers.Number], other: Union[tl.tensor, numbers.Number], builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if hasattr(input, 'was_bool_to_int8'): + if input.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + if hasattr(other, 'was_bool_to_int8'): + if other.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def mod(input: Union[tl.tensor, numbers.Number], other: Union[tl.tensor, numbers.Number], builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if hasattr(input, 'was_bool_to_int8'): + if input.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + if hasattr(other, 'was_bool_to_int8'): + if other.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + # float + if scalar_ty.is_floating(): + floor = math.floor(fdiv(input, other, False, builder), _builder=builder) + ret = sub(input, mul(floor, other, True, builder), True, builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if hasattr(input, 'was_bool_to_int8'): + if input.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, True, builder) + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if input.type.scalar.is_floating(): + raise TypeError(f"unexpected type {input.type.scalar}") + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if input.type.scalar.is_floating(): + raise TypeError(f"unexpected type {input.type.scalar}") + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if input.type.scalar.is_floating(): + raise TypeError(f"unexpected type {input.type.scalar}") + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + assert index.dtype.is_int(), "index must be an integer tensor" + if not src.dtype.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype}") + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = builder.create_gather(src.handle, index.handle, axis) + return wrap_tensor(gather, src.type.scalar, index.type.shape) + +def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + + +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + if hasattr(input, 'was_bool_to_int8'): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." 
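+        # Boolean tensors are materialized as int8 after tl.load (see _load_legacy
+        # in this module); narrow back to int1 so the xor with an all-ones value
+        # below behaves as a logical NOT.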
+ input = cast(input, tl.int1, builder) + input_sca_ty = input.type.scalar + if input_sca_ty.is_floating(): + raise TypeError(f"unexpected type {input_sca_ty}") + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if hasattr(input, 'was_bool_to_int8'): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + input = cast(input, tl.int1, builder) + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if hasattr(other, 'was_bool_to_int8'): + assert other.type.scalar.is_int8(), "Other input wat bool to int8. However, other input.type is not int8." + other = cast(other, tl.int1, builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if hasattr(input, 'was_bool_to_int8'): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + input = cast(input, tl.int1, builder) + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if hasattr(other, 'was_bool_to_int8'): + assert other.type.scalar.is_int8(), "Other wat bool to int8. However, other.type is not int8." + other = cast(other, tl.int1, builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if hasattr(input, 'was_bool_to_int8'): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + input = cast(input, tl.int1, builder) + if input.type.scalar.is_floating(): + raise TypeError(f"unexpected type {input.type.scalar}") + return invert(input, builder) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. 
Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + if other is None: + other = to_tensor(0, builder) + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + is_bool = elt_ty == tl.int1 + if is_bool: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `elt_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + ret = tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + # Do not cast back to int1 when is_bool=true. We directly use the int8 tensor given by tl.load + if is_bool: + ret.was_bool_to_int8 = True + + return ret + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_bool(): + raise TypeError(f"Unexpected dtype {dtype}") + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_bool(): + raise TypeError(f"Unexpected dtype {dtype}") + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + +def extract_slice(ful: tl.tensor, offsets: 
List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + new_indice = [i.handle for i in indice] + result = builder.create_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + # Add `tl.int64` restriction for NPU + if element_ty in [tl.int1, tl.int64, tl.float16, tl.float32, tl.float64, tl.bfloat16] and op in ['or', 'xor']: + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32.") + if element_ty in [tl.int1, tl.int64, tl.float64, tl.bfloat16] and op == 'xchg': + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32, float16, float32.") + if element_ty in [tl.int1, tl.int64, tl.float64]: + raise ValueError(f"atomic_{op} does not support {str(element_ty)}. " + "All support dtypes are int8, int16, int32, float16, float32, bfloat16.") + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty in [tl.int1, tl.int8, tl.float64, tl.bfloat16]: + raise ValueError(f"atomic_cas does not support {str(element_ty)}. 
" + "All support dtypes are int16, int32, int64, float16, float32.") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + # Design for NPU + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + # Design for NPU + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): + if not hint_val: + hint_val = builder.get_unit_attr() + elif isinstance(hint_val, bool): + hint_val = builder.get_bool_attr(hint_val) + elif isinstance(hint_val, int): + hint_val = builder.get_int32_attr(hint_val) + elif isinstance(hint_val, core.constexpr): + hint_val = builder.get_str_attr(hint_val.value) + elif isinstance(hint_val, list): + # only support i64 array attr for now + hint_val = builder.get_i64_array_attr(hint_val) + else: + raise ValueError(f"Unsupported hint value type: {type(hint_val)}") + builder.create_annotation(ptr.handle, hint_name, hint_val) + + +def custom_op(builder: ir.builder, op_name: str, **kwargs): + if op_name == "sync_block_all": + return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"]) + + elif op_name == "sync_block_set": + return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) + + elif op_name == "sync_block_wait": + return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) + + raise ValueError(f"Unsupported custom op: {op_name}") + + +def sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder): + """ + Triton sort 操作 + + 参数: + ptr: tl.tensor,输入张量 + dim: int,排序维度,必须是尾轴(最后一维) + descending: bool 或 constexpr,是否降序 + builder: ir.builder,底层 IR 构建器 + 返回: + values: tl.tensor,排序后的值(类型与输入一致) + """ + + allowed_types = {tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32} + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty not in allowed_types: + raise TypeError( + f"tt.sort only supports int8, int16, bfloat16, float16, float32, " + f"but got {ptr.type}" + ) + + shape = getattr(ptr, "shape", None) 
+ if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("tt.sort requires tensor rank >= 1") + last_dim = rank - 1 + norm_dim = dim if dim >= 0 else dim + rank + if norm_dim != last_dim: + raise ValueError( + f"tt.sort only supports sorting along the last dimension " + f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}" + ) + dim = last_dim + else: + if dim != -1: + raise ValueError( + "tt.sort only supports the last dimension; when rank is unknown " + "you must pass dim=-1" + ) + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + sorted_vals = builder.create_sort(ptr.handle, dim, descending) + + values = tl.tensor(sorted_vals, type=ptr.type) + + return values + + +def _str_to_fp_type(float_format: Optional[str]): + if float_format == 'e4m3': + return ir.F8F6F4TY.E4M3 + if float_format == 'e5m2': + return ir.F8F6F4TY.E5M2 + if float_format == 'e2m3': + return ir.F8F6F4TY.E2M3 + if float_format == 'e3m2': + return ir.F8F6F4TY.E3M2 + if float_format == 'e2m1': + return ir.F8F6F4TY.E2M1 + if float_format == 'bf16': + return ir.F8F6F4TY.BF16 + if float_format == 'fp16': + return ir.F8F6F4TY.FP16 + raise ValueError(f"Invalid float format: {float_format}.") + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. 
Got {val.dtype}" + return bitcast(val, triton_ty, builder) + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: Union[tl.tensor, None], out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value + lhs_format_enum = _str_to_fp_type(lhs_format) + rhs_format_enum = _str_to_fp_type(rhs_format) + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + if not rhs_scale_is_none: + assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) + rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) + + assert lhs.type.shape[-1] == rhs.type.shape[-2], ( + f"lhs last dimension (columns) {lhs.shape[-1]} " + f"must equal rhs penultimate dimension (rows) {rhs.shape[-2]}" + ) + M = lhs.type.shape[-2] + K, N = rhs.type.shape[-2:] + PACKED_A = 2 if lhs_format == "e2m1" else 1 + PACKED_B = 2 if lhs_format == "e2m1" else 1 + assert K * PACKED_B == PACKED_A * lhs.type.shape[ + -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + B = lhs.type.shape[0] if lhs_rank == 3 else None + + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = builder.get_fp32(0) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle + return tl.tensor( + builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, acc_handle), ret_ty) + + + + +def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + return tl.tensor(value, dtype) + + +def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + return cast(value, dtype, builder) + return scalar_constant(value, dtype, builder) + + +def make_tensor_descriptor( + base: 
tl.tensor, + shape: List[tl.tensor], + strides: List[tl.tensor], + block_shape: List[tl.constexpr], + builder: ir.builder +) -> tensor_descriptor: + ndim = len(shape) + if not (1 <= ndim <= 5): + raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + assert isinstance(base.dtype, tl.pointer_type) + primitive_bitwidth = base.dtype.element_ty.primitive_bitwidth + if primitive_bitwidth == 1: + raise ValueError("int1 type is not supported for make_tensor_descriptor yet") + elem_size = primitive_bitwidth // 8 + contig_dim_size = _unwrap_if_constexpr(block_shape[-1]) + if contig_dim_size * elem_size < 16: + raise ValueError( + f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" + ) + + strides[-1] = _unwrap_if_constexpr(strides[-1]) + if strides[-1] != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}") + + shape = [make_scalar(x, tl.int32, builder) for x in shape] + strides = [make_scalar(x, tl.int64, builder) for x in strides] + + block_shape = _unwrap_shape(block_shape) + + assert isinstance(base.type, tl.pointer_type) + desc_block_type = block_type(base.type.element_ty, block_shape) + base_handle = base.handle + is_signed_int = base.type.element_ty.is_int_signed() + + handle = builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], + [s.handle for s in strides], block_shape, is_signed_int) + return tensor_descriptor(handle, shape, strides, desc_block_type) diff --git a/python/triton/language/standard_ext.py b/python/triton/language/standard_ext.py new file mode 100644 index 000000000..64f26670c --- /dev/null +++ b/python/triton/language/standard_ext.py @@ -0,0 +1,100 @@ +from math import pi as math_pi +from triton.language import core, math +from triton.language import float32, int1 +from triton.language.standard import max, sum +from triton.runtime.jit import jit +from triton.language.extra.ascend.libdevice import flip as ascend_flip + + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. 
+ + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + return ascend_flip(x, dim) + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + return (1 / (1 + math.exp(-x.to(core.float32)))).to(x.dtype) + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + z = x.to(core.float32) - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding).to(x.dtype) + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("isfinited") +def isfinited(x): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + nan_mask = math.isnan(x) + inf_mask = math.isinf(x) + return (~nan_mask & ~inf_mask).to(int1) + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("finitef") +def finitef(x): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"finitef only supports float32, but got int8 or int1") + core.static_assert(x.dtype == float32, f"finitef only supports float32, but got {core.constexpr(x.dtype)}") + nan_mask = math.isnan(x) + inf_mask = math.isinf(x) + return (~nan_mask & ~inf_mask).to(int1) + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("rint") +def rint(x): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + return core.where(x >= 0, math.floor(x + 0.5), math.ceil(x - 0.5)) + +pi: core.constexpr = math_pi + +@core._tensor_member_fn +@jit +@math._add_math_2arg_docstr("atan2") +def atan2(y, x): + _is_int8_type_x: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type_x, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_int8_type_y: core.constexpr = y.dtype.is_int8() + core.static_assert(not _is_int8_type_y, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type_x: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type_x == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + _is_floating_type_y: core.constexpr = y.dtype.is_floating() + core.static_assert(_is_floating_type_y == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(y.dtype)}") + half_pi: core.constexpr = 0.5 * pi + base = core.where(x == 0, 0.0, 
math.atan(y.to(float32) / x.to(float32))) + base = core.where((x == 0) & (y > 0), half_pi, base) + base = core.where((x == 0) & (y < 0), -half_pi, base) + + add_pi = core.where((x < 0) & (y >= 0), pi, 0.0) + sub_pi = core.where((x < 0) & (y < 0), -pi, 0.0) + return (base + add_pi + sub_pi).to(x.dtype) diff --git a/python/triton/language/tensor_descriptor.py b/python/triton/language/tensor_descriptor.py new file mode 100644 index 000000000..077c889e1 --- /dev/null +++ b/python/triton/language/tensor_descriptor.py @@ -0,0 +1,692 @@ +# TODO: When upgrading to Triton 3.4.0, remove this file, +# use the upstream Triton functions, and update core.py and semantic.py accordingly. +from __future__ import annotations + +import builtins +from typing import List, Tuple, Sequence, TypeVar +from enum import Enum + +from triton._C.libtriton import ir +from triton.language.core import ( + builtin, + constexpr, + tensor, + _value, + void as real_void, +) + +from triton.language.semantic import ( + _convert_to_ir_values, + _str_to_load_cache_modifier, + _str_to_eviction_policy, +) + +from ._utils import validate_block_shape, get_primitive_bitwidth + + +def _unwrap_if_constexpr(o): + if isinstance(o, list): + return [_unwrap_if_constexpr(x) for x in o] + if isinstance(o, builtins.tuple): + return builtins.tuple(_unwrap_if_constexpr(x) for x in o) + if isinstance(o, tuple): + return tuple(_unwrap_if_constexpr(x) for x in o) + return o.value if isinstance(o, constexpr) else o + + +def _unwrap_shape(shape): + shape = _unwrap_if_constexpr(shape) + return [_unwrap_if_constexpr(s) for s in shape] + + +def _normalize_tuple(t): + normalized_tuple = _unwrap_if_constexpr(t) + if isinstance(normalized_tuple, (list, builtins.tuple)): + normalized_tuple = tuple(normalized_tuple) + return normalized_tuple + + +def descriptor_load(desc: tensor_descriptor_base, offsets, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tensor: + assert isinstance(desc, tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tensor(x, desc.block_type) + + +def validate_store_like(desc: tensor_descriptor_base, value: tensor, offsets) -> None: + assert isinstance(desc, tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + assert value.shape == desc.block_shape + + +def descriptor_store(desc: tensor_descriptor_base, value: tensor, offsets, builder: ir.builder) -> tensor: + validate_store_like(desc, value, offsets) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), real_void) + + + +class base_value(_value): + """Base class of values that exist in the triton IR (i.e. not constexprs). 
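Before the tensor-descriptor shim below, a quick usage note for the standard_ext helpers added above: they are plain @jit functions (and tensor member functions via _tensor_member_fn), so a kernel can call them like any other math helper. A hypothetical sketch, assuming the module is imported directly (how FlagTree re-exports it into the tl namespace is not shown in this hunk):

import triton
import triton.language as tl
from triton.language import standard_ext as tle  # hypothetical alias; wiring not shown in this patch

@triton.jit
def activation(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tle.sigmoid(x)  # equivalently x.sigmoid(), thanks to _tensor_member_fn
    tl.store(y_ptr + offs, y, mask=mask)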
+ """ + type: base_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + """Flatten frontend value into a sequence of mlir handles, which are appended + to the output list + """ + raise NotImplementedError + + +class base_type: + + def __eq__(self, other): + raise NotImplementedError("Types must implement __eq__") + + def __ne__(self, other): + return not (self == other) + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + """Build a frontend value with the current dtype, wrapping a list of existing handles. + cursor is the index of the first handle relevant to this value, and the function + should return the updated cursor position after any handles consumed by the created value. + """ + raise NotImplementedError + + def mangle(self) -> str: + raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}") + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + raise NotImplementedError + + +class tuple(base_value): + + def __init__(self, args: Sequence, type: tuple_type = None): + self.values = [i for i in args] + + def get_type(x): + if isinstance(x, dtype): + return dtype + if isinstance(x, (int, float)): + return constexpr + return x.type + + self.type = type or tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + def __getattr__(self, name): + return self.values[self.type.fields.index(name)] + + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + other = _normalize_tuple(other) + return tuple(self.values + other.values) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + other = _normalize_tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + def _flatten_ir(self, handles: List[ir.value]): + for v in self.values: + print("[debug]tuple _flatten_ir: value:", v) + v._flatten_ir(handles) + print("[debug]tuple _flatten_ir: handles:", handles) + + def __repr__(self): + return f"({' ,'.join(repr(x) for x in self.values)})" + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + if not isinstance(ty, constexpr): + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, 
cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' + + +class dtype(base_type): + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + self.primitive_bitwidth = get_primitive_bitwidth(name) + self.itemsize = self.primitive_bitwidth // 8 + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif 
self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return tensor(handles[cursor], self), cursor + 1 + + def mangle(self) -> str: + if self.is_int(): + SIGNED = dtype.SIGNEDNESS.SIGNED + prefix = 'i' if self.int_signedness == SIGNED else 'u' + return prefix + str(self.int_bitwidth) + if self.is_floating(): + return str(self) + if self.is_void(): + return 'V' + return super().mangle() + + def with_element_ty(self, element_ty: dtype): + assert not self.is_block() + return element_ty + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + 
self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert (isinstance(shape, (list, tuple))) + + # shape can be empty ([]) when an input is a 0D tensor. + self.shape = tuple(_unwrap_shape(shape)) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> Tuple[int]: + return self.shape + + def with_element_ty(self, scalar_ty: dtype) -> block_type: + return block_type(scalar_ty, self.shape) + + def __eq__(self, other) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + @property + def scalar(self): + return self.element_ty + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = '_'.join(map(str, self.shape)) + return f'{elt}S{shape}S' + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + if not isinstance(ty, constexpr): + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' + + +class tensor_descriptor_base_type(base_type): + + def __init__(self, block_type: block_type): + self.block_type = block_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + value = tensor_descriptor_base(handles[cursor], self.block_type) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed)) + + def __str__(self) -> str: + # ex. 
"tensor_descriptor" + return f"tensor_descriptor<{self.block_type}>" + + def __eq__(self, other) -> bool: + if type(other) is not type(self): + return False + return self.block_type == other.block_type + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}" + + +class tensor_descriptor_base(base_value): + """" + A tensor descriptor with unknown shape and strides + """ + + def __init__(self, handle, block_type: block_type): + """Not called by user code.""" + super().__init__(handle) + + self.handle = handle # IR handle + self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return descriptor_load(self, offsets, "", "", _builder) + + @builtin + def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return descriptor_store(self, value, offsets, _builder) + + +class tensor_descriptor_type(tensor_descriptor_base_type): + + def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): + self.block_type = block_type + self.shape_type = shape_type + self.strides_type = strides_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + shape = shape.values + strides = strides.values + value = tensor_descriptor(handle, shape, strides, self.block_type) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + super()._flatten_ir_types(builder, out) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def __eq__(self, other): + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + == other.strides_type) + + +class tensor_descriptor(tensor_descriptor_base): + """A descriptor representing a tensor in global memory. 
+ """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, block_type) + # Global shape + self.shape = tuple(shape) + self.strides = tuple(strides) + self.type = tensor_descriptor_type( + block_type, + shape_type=self.shape.type, + strides_type=self.strides.type, + ) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) diff --git a/python/triton/language/tensor_descriptor.py.none b/python/triton/language/tensor_descriptor.py.none new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/runtime/autotiling_tuner.py b/python/triton/runtime/autotiling_tuner.py new file mode 100644 index 000000000..357b4509a --- /dev/null +++ b/python/triton/runtime/autotiling_tuner.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import builtins +import os +import time +from typing import Dict, List + +from .autotuner import Autotuner, Config +from .utils import get_byte_per_numel, is_valid_axis_name + + +class AutoTilingTuner(Autotuner): + """ + Automatic generateing candidate tiling configs and evaluating their performance to get the best config. + """ + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=None, + rep=None, + use_cuda_graph=False, + do_bench=None, + auto_profile_dir=None, + split_params=None, + tiling_params=None, + low_dims=None, + dual_reduction=False, + persistent_reduction=False, + ): + """ + :param key: a dict of axis name: argument name, where the change of arguments in value will triger re-generating candidates configs and evaluating. + The axis name should be in {'x','y','z','w','v','t','rx','ry','rz','rw','rv','rt}, where the prefix 'r' means a reduction axis. + Only the axis name in this param should add perfix 'r' if it's a reduction axis. + :type key: Dict[str, str] + :param split_params: a dict of axis name: argument name, the argument is an adjustable parameter in a split axis, such as 'XBLOCK'. + The axis name must be in key's axis names. Do not add prefix 'r' before the axis name. + This param can be empty. Note that the auto tiling feature will be disabled when the split_params and tiling_params are both empty. + The split axis can usually be identified according to `tl.program_id()` expression. + :type split_params: Dict[str, str] + :param tiling_params: a dict of axis name: argument name, the argument is an adjustable parameter in a tiling axis, such as 'XBLOCK_SUB'. + The axis name must be in key's axis names. Do not add prefix 'r' before the axis name. + This param can be empty. Note that the auto tiling feature will be disabled when the split_params and tiling_params are both empty. + The tiling axis can usually be identified according to `tl.arange()` expression. + :type tiling_params: Dict[str, str] + :param low_dims: a list of axis name in which the corresponding axis is low dim aixs. + The axis name must be in key's axis names. Do not add prefix 'r' before the axis name. + :type low_dims: List[str] + :param dual_reduction: performing reduction on more than one axis. + :param persistent_reduction: there is no splitting in reduction axis. 
+ """ + super().__init__( + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook, + post_hook, + prune_configs_by, + warmup, + rep, + use_cuda_graph, + do_bench, + auto_profile_dir, + ) + + if not configs: + self.user_configs = [] + else: + self.user_configs = configs + self.gen_configs = [] # generated configs from TileGenerator + + self.split_params = split_params + self.tiling_params = tiling_params + self.low_dims = low_dims + self.dual_reduction = dual_reduction + self.persistent_reduction = persistent_reduction + + def _gen_tile_configs( + self, kv_dict: Dict[str, int], dtype: torch.dtype + ) -> List[Config]: + from .tile_generator import KernelMeta, TileGenerator + + axis_sizes = {} + for k, v in kv_dict.items(): + if not is_valid_axis_name(k): + continue + if not isinstance(v, int): + raise ValueError( + f"Not supported dim type: {type(v)}, `int` is the only supported type" + ) + axis_sizes[k] = v + + kernel_meta = KernelMeta( + axis_sizes, + self.split_params, + self.tiling_params, + self.low_dims, + dtype, + self.persistent_reduction, + self.dual_reduction, + ) + tile_gen = TileGenerator(kernel_meta=kernel_meta) + tile_gen.descend_split_tiling() + + self.gen_configs.clear() + self.gen_configs = list(tile_gen.configs.values()) + if len(self.gen_configs) == 0: + print( + "[WARNING] The generated candidate tiling configs are empty based on provided parameters!" + ) + + if len(self.gen_configs) == 0 and len(self.user_configs) == 0: + return [ + Config({}) + ] + else: + return self.gen_configs + self.user_configs + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + + # generate key + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + _kv_dict = {k: _args[v] for k, v in self.keys.items() if v in _args} + key = list(_kv_dict.values()) + + # Currently, we use the dtype with maximum byte length + dtype = None + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + dtype = ( + arg.dtype + if get_byte_per_numel(arg.dtype) >= get_byte_per_numel(dtype) + else dtype + ) + if dtype is None: + raise NotImplementedError("Not support for non-Tensor inputs") + + key = tuple(key) + if key not in self.cache: + # prune configs + self.configs = self._gen_tile_configs(_kv_dict, dtype) + pruned_configs = self.prune_configs(kwargs) + if len(pruned_configs) > 1: + used_cached_result = False + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = pruned_configs[0] + else: + config = self.cache[key] + + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print( + f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};" + ) + + if not used_cached_result and self.auto_profile_dir is not None: + self._profile(*args, config=self.best_config, **kwargs) + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + 
*args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret diff --git a/python/triton/runtime/autotiling_tuner.py.none b/python/triton/runtime/autotiling_tuner.py.none new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 9f494a062..67753e129 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -4,7 +4,7 @@ import os import time import inspect -from typing import Dict +from typing import Dict, List from .jit import KernelInterface from .errors import OutOfResources @@ -28,6 +28,7 @@ def __init__( rep=None, use_cuda_graph=False, do_bench=None, + auto_profile_dir=None, ): """ :param prune_configs_by: a dict of functions that are used to prune configs, fields: @@ -37,8 +38,7 @@ def __init__( """ if not configs: self.configs = [ - Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, - reg_dec_producer=0, reg_inc_consumer=0) + Config({}) ] else: self.configs = configs @@ -100,6 +100,7 @@ def _post_hook(kwargs, exception): self.num_warmups = warmup self.num_reps = rep self.use_cuda_graph = use_cuda_graph + self.auto_profile_dir = auto_profile_dir # If we got explicitly called via the old interface, raise a warning # and proceed with the old behavior. @@ -132,7 +133,7 @@ def _post_hook(kwargs, exception): self.do_bench = do_bench def _bench(self, *args, config, **meta): - from ..compiler.errors import CompileTimeAssertionFailure + from ..compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError # check for conflicts, i.e. meta-parameters both provided # as kwargs and by the autotuner @@ -164,9 +165,44 @@ def kernel_call(): try: return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - except (OutOfResources, CompileTimeAssertionFailure): + except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as e: return [float("inf"), float("inf"), float("inf")] + def _profile(self, *args, config, **meta): + from triton.testing import do_bench_npu + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + do_bench_npu( + kernel_call, prof_dir=self.auto_profile_dir, keep_res=True + ) + def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) used_cached_result = True @@ -197,6 +233,9 @@ def run(self, *args, **kwargs): if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + + if not used_cached_result and self.auto_profile_dir is not None: + self._profile(*args, config=self.best_config, **kwargs) if config.pre_hook is not None: full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} config.pre_hook(full_nargs) @@ -260,10 +299,11 @@ class Config: to ptx .maxnreg directive. Not supported on all platforms. :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this function are args. + :ivar bishengir_options: dict of options that pass to bishengir. """ - def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, - reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): + def __init__(self, kwargs, num_warps=None, num_stages=None, num_ctas=None, num_buffers_warp_spec=None, num_consumer_groups=None, + reg_dec_producer=None, reg_inc_consumer=None, maxnreg=None, pre_hook=None, **bishengir_options): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas @@ -275,6 +315,17 @@ def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_wa self.maxnreg = maxnreg self.pre_hook = pre_hook + + # BiShengIR Options allowed for autotune + self.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True + self.unit_flag = bishengir_options.get("unit_flag", None) # Compiler Default False + self.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get("limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False + self.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", None) # Compiler Default no-limit + self.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 + self.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", None) # Compiler Default True + self.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 + self.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 + def all_kwargs(self): return { **self.kwargs, **{ @@ -288,6 +339,16 @@ def all_kwargs(self): ("reg_dec_producer", self.reg_dec_producer), ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), + + ("multibuffer", self.multibuffer), + ("enable_hivm_auto_cv_balance", self.enable_hivm_auto_cv_balance), + ("unit_flag", self.unit_flag), + ("limit_auto_multi_buffer_only_for_local_buffer", \ + 
self.limit_auto_multi_buffer_only_for_local_buffer), + ("limit_auto_multi_buffer_of_local_buffer", self.limit_auto_multi_buffer_of_local_buffer), + ("set_workspace_multibuffer", self.set_workspace_multibuffer), + ("tile_mix_vector_loop", self.tile_mix_vector_loop), + ("tile_mix_cube_loop", self.tile_mix_cube_loop), ) if v is not None } } @@ -304,11 +365,22 @@ def __str__(self): res.append(f"reg_dec_producer: {self.reg_dec_producer}") res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") + + res.append(f"multibuffer: {self.multibuffer}") + res.append(f"enable_hivm_auto_cv_balance: {self.enable_hivm_auto_cv_balance}") + res.append(f"unit_flag: {self.unit_flag}") + res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ + {self.limit_auto_multi_buffer_only_for_local_buffer}") + res.append(f"limit_auto_multi_buffer_of_local_buffer: {self.limit_auto_multi_buffer_of_local_buffer}") + res.append(f"set_workspace_multibuffer: {self.set_workspace_multibuffer}") + res.append(f"tile_mix_vector_loop: {self.tile_mix_vector_loop}") + res.append(f"tile_mix_cube_loop: {self.tile_mix_cube_loop}") return ", ".join(res) -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, - warmup=None, rep=None, use_cuda_graph=False, do_bench=None): +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, + pre_hook=None, post_hook=None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None, auto_profile_dir=None, + split_params=None, tiling_params=None, low_dims=None, dual_reduction=False, persistent_reduction=False): """ Decorator for auto-tuning a :code:`triton.jit`'d function. @@ -362,12 +434,23 @@ def kernel(x_ptr, x_size, **META): :type rep: int :param do_bench: a benchmark function to measure the time of each run. :type do_bench: lambda fn, quantiles + :param auto_profile_dir: a directory for storing the profiling result of the best config. + It will automatically profile the best configuration when the value is not None. 
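A usage sketch for the extended autotune interface documented above, combining the auto-tiling parameters with one BiShengIR option carried on Config; the kernel, axis, and parameter names are illustrative, not taken from this patch:

import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({}, multibuffer=True)],      # BiShengIR options ride on Config kwargs
    key={"x": "n_rows", "rx": "n_cols"},                 # 'r' prefix marks the reduction axis
    split_params={"x": "XBLOCK"},                        # split-axis parameter (tl.program_id side)
    tiling_params={"x": "XBLOCK_SUB", "rx": "RBLOCK"},   # tiling-axis parameters (tl.arange side)
    low_dims=["rx"],
    persistent_reduction=False,
    auto_profile_dir="./autotune_prof",                  # profile the winning config when set
)
@triton.jit
def row_reduce(out_ptr, in_ptr, n_rows, n_cols,
               XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, RBLOCK: tl.constexpr):
    pass  # body elided; only the tunable meta-parameters matter for this sketch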
+ :type auto_profile_dir: str """ def decorator(fn): - return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, - post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph) + if split_params or tiling_params: + from .autotiling_tuner import AutoTilingTuner + return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir, + split_params=split_params, tiling_params=tiling_params, low_dims=low_dims, + dual_reduction=dual_reduction, persistent_reduction=persistent_reduction) + else: + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir) return decorator diff --git a/python/triton/runtime/autotuner.py.std b/python/triton/runtime/autotuner.py.std new file mode 100644 index 000000000..9f494a062 --- /dev/null +++ b/python/triton/runtime/autotuner.py.std @@ -0,0 +1,408 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict + +from .jit import KernelInterface +from .errors import OutOfResources +from .driver import driver + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=None, + rep=None, + use_cuda_graph=False, + do_bench=None, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
+ """ + if not configs: + self.configs = [ + Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) + ] + else: + self.configs = configs + self.keys = key + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_to_zero = [] + if reset_to_zero is not None: + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + # Hook to reset or restore for required tensors + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False + if pre_hook: + self.pre_hook = pre_hook + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): + + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() + if not reset_only: + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + + self.num_warmups = warmup + self.num_reps = rep + self.use_cuda_graph = use_cuda_graph + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. + if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph: + from ..testing import do_bench_cudagraph + self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + if do_bench is None: + self.do_bench = driver.active.get_benchmarker() + else: + self.do_bench = do_bench + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. 
+ :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. 
+ :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/python/triton/runtime/code_cache.py b/python/triton/runtime/code_cache.py new file mode 100644 index 000000000..66534c505 --- /dev/null +++ b/python/triton/runtime/code_cache.py @@ -0,0 +1,59 @@ +# Copyright © 2024 BAAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modifications: +# - 2025-06-03: +# - init version: e9c7aa71832eb2f897a49ce787e42d5377404a72 +# + +import functools +import os +import shutil +from pathlib import Path + + +@functools.lru_cache(maxsize=None) # this is the same as functools.cache in Python 3.9+ +def cache_dir_path() -> Path: + """Return the cache directory for generated files in flaggems.""" + _cache_dir = os.environ.get("FLAGGEMS_CACHE_DIR") + if _cache_dir is None: + _cache_dir = Path.home() / ".flaggems" + else: + _cache_dir = Path(_cache_dir) + return _cache_dir + + +def cache_dir() -> Path: + """Return cache directory for generated files in flaggems. Create it if it does not exist.""" + _cache_dir = cache_dir_path() + os.makedirs(_cache_dir, exist_ok=True) + return _cache_dir + + +def code_cache_dir() -> Path: + _code_cache_dir = cache_dir() / "code_cache" + os.makedirs(_code_cache_dir, exist_ok=True) + return _code_cache_dir + + +def config_cache_dir() -> Path: + _config_cache_dir = cache_dir() / "config_cache" + os.makedirs(_config_cache_dir, exist_ok=True) + return _config_cache_dir + + +def clear_cache(): + """Clear the cache directory for code cache.""" + _cache_dir = cache_dir_path() + shutil.rmtree(_cache_dir) diff --git a/python/triton/runtime/code_cache.py.none b/python/triton/runtime/code_cache.py.none new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 45178a40b..08422611f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -6,11 +6,14 @@ import os import re import textwrap +import tokenize from collections import defaultdict from functools import cached_property from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple from ..runtime.driver import driver +from ..backends.ascend.compiler import AscendAttrsDescriptor from types import ModuleType +from io import StringIO TRITON_MODULE = __name__[:-len(".runtime.jit")] @@ -328,7 +331,6 @@ def __getitem__(self, grid) -> T: memorizes the grid. 
""" return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) - # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) def serialize_specialization_data(name, signature, constants, attrs, options, key): @@ -566,7 +568,8 @@ def run(self, *args, grid, warmup, **kwargs): # parse options from ..compiler import make_backend device = driver.active.get_current_device() - stream = driver.active.get_current_stream(device) + if ('stream' not in kwargs.keys()): + stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() backend = make_backend(target) @@ -590,10 +593,20 @@ def run(self, *args, grid, warmup, **kwargs): # deprecated arguments assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" assert "device" not in kwargs, "device option is deprecated; current device will be used" - assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + # assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" for k in excess_kwargs: if k not in options.__dict__: raise KeyError("Keyword argument %s was specified but unrecognised" % k) + ignor_params = ["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ + "allowed_dot_input_precisions", "multibuffer", "stream"] + not_work_params = [] + for k in kwargs: + if k in ignor_params: + continue + elif k in excess_kwargs: + not_work_params.append(k) + if len(not_work_params) != 0: + print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) bound_vals = tuple(bound_args.values()) @@ -647,9 +660,16 @@ def run(self, *args, grid, warmup, **kwargs): grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - + grid_all_size = grid_0 * grid_1 * grid_2 + if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": + if grid_all_size > 65535: + raise RuntimeError("grid should be less than 65536! You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem.") + if ('stream' in kwargs.keys()): + stream = kwargs["stream"] # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + # explicitly define run method and load kernel binary + kernel._init_handles() kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) return kernel @@ -729,7 +749,6 @@ def warmup(self, *args, grid, **kwargs): def preload(self, specialization_data): from ..compiler import compile, ASTSource - from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl device = driver.active.get_current_device() @@ -742,7 +761,7 @@ def preload(self, specialization_data): for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) - src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + src = ASTSource(self, signature, constants, AscendAttrsDescriptor.from_dict(deserialized_obj['attrs'])) options = { key: tuple(value) if isinstance(value, list) else value for key, value in deserialized_obj['options'].items() @@ -756,10 +775,29 @@ def preload(self, specialization_data): # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. 
def parse(self): + # Maps line numbers to comment hints + line_flagtree_hints = {} + code_str = self.src + g = tokenize.generate_tokens(StringIO(code_str).readline) + for tok_type, tok_text, start, end, _ in g: + if tok_type == tokenize.COMMENT: + comment = tok_text.replace(" ", "").strip() + if comment.startswith('#@hint:'): + flagtree_hints = comment[len('#@hint:'):].strip() + # Record the line number of the comment + line_num = start[0] + line_flagtree_hints[line_num] = flagtree_hints + + # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + tree = ast.parse(self.src) assert isinstance(tree, ast.Module) assert len(tree.body) == 1 assert isinstance(tree.body[0], ast.FunctionDef) + + # Attach the line number to comment mapping to the function definition node + tree.body[0].line_flagtree_hints = line_flagtree_hints + return tree def __call__(self, *args, **kwargs): diff --git a/python/triton/runtime/jit.py.std b/python/triton/runtime/jit.py.std new file mode 100644 index 000000000..45178a40b --- /dev/null +++ b/python/triton/runtime/jit.py.std @@ -0,0 +1,951 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). 
+ self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + if isinstance(func, JITFunction): + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) is not ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + self._update_hash(val) + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + ret = getattr(lhs, node.attr) + self._update_hash(ret) + return ret + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? 
vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v, align): + + if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. 
+ if align and (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif hasattr(arg, "tma_desc_cpu_ptr"): + return "nvTmaDesc" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams, backend): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. 
+ """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + if not kp.do_not_specialize_on_alignment: + specialisations.append('compute_spec_key(%s, align=True)' % name) + else: + specialisations.append('compute_spec_key(%s, align=False)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = backend.compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # cache_hook will always be called before compilation and compiled_hook after. 
+ compiled_hook = None + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + is_warmup, + before, + ): + hook = JITFunction.cache_hook if before else JITFunction.compiled_hook + if hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + 'is_warmup': is_warmup, + } + + return hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=is_warmup, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self, backend): + """ + Precompute as much as possible. 
+ """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params, backend) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + + # parse options + from ..compiler import make_backend + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + target = driver.active.get_current_target() + backend = make_backend(target) + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder(backend) + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or (p.num in constant_params) or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) + + # Check that used global values have not changed. 
+ not_present = object() + for (name, _), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. 
+ self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... 
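For reference, the two decorator forms declared by these overloads (bare and parameterised) are used as follows; the kernels are placeholders, not part of the patch.

import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

# Parameterised form: `n` is excluded from value specialisation, so n == 1 or a
# 16-divisible n does not force a separate compilation (JITFunction.__init__ accepts
# parameter names as well as indices here).
@triton.jit(do_not_specialize=["n"])
def copy_kernel_dns(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)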
+ + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. 
+ return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line diff --git a/python/triton/runtime/libentry.py b/python/triton/runtime/libentry.py new file mode 100644 index 000000000..fcf41fcfb --- /dev/null +++ b/python/triton/runtime/libentry.py @@ -0,0 +1,297 @@ +# Copyright © 2024 BAAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modifications: +# - 2025-06-03: +# - init version: e9c7aa71832eb2f897a49ce787e42d5377404a72 +# - adapt torch_device_fn to ascend +# + +import inspect +import sqlite3 +import threading +import weakref +from typing import Dict +import ast + +import triton +from triton._C import libentry_ascend +import torch +import torch_npu +torch_device_fn = torch.npu + +from .code_cache import config_cache_dir + +DEVICE_COUNT = torch_device_fn.device_count() +major_version = int(triton.__version__.split(".")[0]) + + +def quote_identifier(name: str) -> str: + if not name: + raise ValueError("empty identifier") + allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_") + if not (name[0].isalpha() or name[0] == "_"): + raise ValueError("identifier must start with letter or _") + if not all(ch in allowed for ch in name): + raise ValueError("identifier contains illegal char") + return '"' + name.replace('"', '""') + '"' + + +class LibTuner(triton.runtime.Autotuner): + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + use_cuda_graph=False, + ): + if major_version == 2: + super().__init__( + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by, + warmup, + rep, + ) + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + else: + super().__init__( + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook, + post_hook, + prune_configs_by, + warmup, + rep, + use_cuda_graph, + ) + self.__name__ = self.base_fn.__name__ + self.table_name = quote_identifier(self.__name__) + self.cache_path = config_cache_dir() / "TunedConfig.db" + self.preload() + weakref.finalize(self, self.store) + + def preload(self): + connect = sqlite3.connect(self.cache_path) + c = connect.cursor() + c.execute( + f"CREATE TABLE IF NOT EXISTS {self.table_name} (key TEXT PRIMARY KEY, config TEXT)" + ) + cursor = c.execute(f"SELECT key, config from {self.table_name}") + + for row in 
cursor: + key_str, config_str = row + key = [ast.literal_eval(k) for k in key_str[1:-1].split(", ")] + + cfg_ls = [item.split(": ") for item in config_str.split(", ")] + config = triton.Config({}) + attrs = -5 if major_version == 2 else -4 + for k, v in cfg_ls[:attrs]: + config.kwargs[k] = ast.literal_eval(v) + config.num_warps = ast.literal_eval(cfg_ls[attrs][1]) + config.num_ctas = ast.literal_eval(cfg_ls[attrs + 1][1]) + config.num_stages = ast.literal_eval(cfg_ls[attrs + 2][1]) + if major_version == 2: + config.enable_warp_specialization = ast.literal_eval(cfg_ls[attrs + 3][1]) + config.enable_persistent = ast.literal_eval(cfg_ls[attrs + 4][1]) + else: + config.maxnreg = ast.literal_eval(cfg_ls[attrs + 3][1]) + + self.cache[tuple(key)] = config + + connect.close() + self.volumn = len(self.cache) + + def store(self): + if len(self.cache) == self.volumn: + return + connect = sqlite3.connect(self.cache_path) + c = connect.cursor() + c.execute( + f"CREATE TABLE IF NOT EXISTS {self.table_name} (key TEXT PRIMARY KEY, config TEXT)" + ) + for key, config in self.cache.items(): + c.execute( + f"INSERT OR IGNORE INTO {self.table_name} (key, config) VALUES (?, ?)", + (str(key), config.__str__()), + ) + + connect.commit() + connect.close() + + +def libtuner( + configs, + key, + prune_configs_by=None, + reset_to_zero=None, + restore_value=None, + pre_hook=None, + post_hook=None, + warmup=25, + rep=100, + use_cuda_graph=False, +): + """ + Decorator for triton library autotuner. + """ + + def decorator(fn): + return LibTuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=pre_hook, + post_hook=post_hook, + prune_configs_by=prune_configs_by, + warmup=warmup, + rep=rep, + use_cuda_graph=use_cuda_graph, + ) + + return decorator + + +class LibEntry(triton.KernelInterface): + def __init__( + self, + fn, + ): + self.fn = fn + self.arg_names = fn.arg_names + self.divisibility = 16 + self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT)) + + while not isinstance(fn, triton.runtime.JITFunction): + fn = fn.fn + self.jit_function: triton.runtime.JITFunction = fn + self.specialize_indices = [ + p.num + for p in self.jit_function.params + if not p.is_constexpr and not p.do_not_specialize + ] + self.do_not_specialize_indices = [ + p.num + for p in self.jit_function.params + if not p.is_constexpr and p.do_not_specialize + ] + self.lock = threading.Lock() + + def run(self, *args, **kwargs): + grid = kwargs["grid"] + + # collect all the arguments + spec_args = [] # specialize arguments + dns_args = [] # do not specialize arguments + const_args = [] # constexpr arguments + k_args = [] # kernel arguments + arg_processor = libentry_ascend.ArgProcessor(self.divisibility) + arg_processor.classify_arguments( + list(args), + kwargs, + self.jit_function.params, + set(self.specialize_indices), + set(self.do_not_specialize_indices) + ) + + k_args = arg_processor.get_k_args() + + entry_key = arg_processor.generate_key() + device = torch_device_fn.current_device() + cache = self.kernel_cache[device] + while entry_key not in cache: + # NOTE: we serialize the first run of a jit function regardless of which device to run on + # because Triton runtime is currently not threadsafe. 
+ with self.lock: + if entry_key in cache: + break + kernel = self.fn.run(*args, **kwargs) + fn = self.fn + # collect constexpr arguments for grid computation + constexprs = {} + while not isinstance(fn, triton.runtime.JITFunction): + if isinstance(fn, triton.runtime.Autotuner): + config = fn.best_config + constexprs["num_warps"] = config.num_warps + constexprs["num_stages"] = config.num_stages + constexprs["num_ctas"] = config.num_ctas + constexprs = {**constexprs, **config.kwargs} + elif isinstance(fn, triton.runtime.Heuristics): + for v, heur in fn.values.items(): + constexprs[v] = heur( + { + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + } + ) + else: + raise RuntimeError("Invalid Runtime Function") + fn = fn.fn + for p in self.jit_function.params: + if ( + p.is_constexpr + and p.name not in constexprs + and (p.default is not inspect._empty) + ): + constexprs[p.name] = p.default + cache[entry_key] = (kernel, constexprs) + return kernel, constexprs + + kernel, constexprs = cache[entry_key] + + if callable(grid): + # collect all arguments to the grid fn,ie: + # 1. args, + # 2. kwargs, + # 3. all all other captured arguments in CompiledKernel from Autotunner & Heuristics + # when kwargs & captured args conflict, captured args have higher priority + meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs} + grid = grid(meta) + grid = grid + (1, 1) + + kernel[grid[0:3]](*k_args) + return kernel, constexprs + + +def libentry(): + """ + Decorator for triton library entries. + """ + + def decorator(fn): + return LibEntry(fn) + + return decorator \ No newline at end of file diff --git a/python/triton/runtime/libentry.py.none b/python/triton/runtime/libentry.py.none new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/runtime/tile_generator.py b/python/triton/runtime/tile_generator.py new file mode 100644 index 000000000..453cce7e2 --- /dev/null +++ b/python/triton/runtime/tile_generator.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import functools +from dataclasses import dataclass +from typing import ( + Dict, + List, + Tuple, +) + +from .utils import ( + get_byte_per_numel, + next_power_of_2, + num_vector_core, +) +from .autotuner import Config + + +@dataclass +class AxisInfo: + name: str + index: int + length: int + + prefix: str = "" + split_name: str = "" + tiling_name: str = "" + is_split_axis: bool = False + is_tiling_axis: bool = False + + @property + def is_reduction(self): + return self.prefix == "r" + + +class KernelMeta: + def __init__( + self, + axis_sizes: Dict[str, int], + split_params: Dict[str, str], + tiling_params: Dict[str, str], + low_dims: List[str], + dtype: torch.dtype, + persistent_reduction: bool, + dual_reduction: bool, + ): + self._validate_axis(axis_sizes, split_params, tiling_params, low_dims) + + axis_dict = {} + idx = 0 + for name, length in axis_sizes.items(): + prefix = "" + if name.startswith("r"): + prefix = "r" + name = name[1:] + + is_split_axis = name in split_params + is_tiling_axis = name in tiling_params + split_name = "" if name not in split_params else split_params[name] + tiling_name = "" if name not in tiling_params else tiling_params[name] + + axis_dict[name] = AxisInfo( + name=name, + index=idx, + length=length, + prefix=prefix, + split_name=split_name, + tiling_name=tiling_name, + is_split_axis=is_split_axis, + is_tiling_axis=is_tiling_axis, + ) + idx += 1 + + self.axis_info = list(axis_dict.values()) + self.split_axis = [x for x in axis_dict.values() if x.is_split_axis] + 
self.tiling_axis = [x for x in axis_dict.values() if x.is_tiling_axis] + self.low_dims_axis = [x for x in axis_dict.values() if x.name in low_dims] + self.dtype = dtype + self.persistent_reduction = persistent_reduction + self.dual_reduction = dual_reduction + + @classmethod + def _validate_axis( + cls, + axis_sizes: Dict[str, int], + split_params: Dict[str, str], + tiling_params: Dict[str, str], + low_dims: List[str], + ) -> None: + for axis_name in axis_sizes.keys(): + if axis_name.startswith("r") and len(axis_name) == 1: + raise ValueError("The name of a reduction axis is empty!") + + def check_keys(params: List[str], context="parameter"): + for k in params: + if k not in axis_sizes and ("r" + k) not in axis_sizes: + raise KeyError( + f"{context} '{k}' not found in known axes: {axis_sizes.keys()}" + ) + + check_keys(split_params.keys(), "split axis") + check_keys(tiling_params.keys(), "tiling axis") + check_keys(low_dims, "low dim axis") + + +@dataclass +class BlockInfo: + block_name: str # e.g., XBLOCK + sub_block_name: str # e.g., XBLOCK_SUB + block_size: int + sub_block_size: int + + +""" +Generate possible candidate tiling configs for benchmarking +""" +class TileGenerator: + num_warps = 1 + num_stages = 1 + stop_bytes = 1024 + max_tile_bytes = 16384 * 4 + + def __init__(self, kernel_meta: KernelMeta): + self.kernel_meta = kernel_meta + self.persistent_reduction = self.kernel_meta.persistent_reduction + self.dual_reduction = self.kernel_meta.dual_reduction + + self.blocks = self.init_blocks_info(kernel_meta) + self.split_magic_nums = [] + for axis in self.kernel_meta.axis_info: + if axis.is_split_axis: + self.split_magic_nums.append( + (axis.length + num_vector_core - 1) // num_vector_core + ) + else: + self.split_magic_nums.append(-1) + + self.candidates_blk_sizes: List[Tuple[int, ...]] = [] + self.configs = {} + self.dtype_bytes = get_byte_per_numel(kernel_meta.dtype) + self.stop_numel = self.stop_bytes // self.dtype_bytes + self.max_tile_numel = self.max_tile_bytes // self.dtype_bytes + + @classmethod + def init_blocks_info(cls, kernel_meta: KernelMeta) -> List[BlockInfo]: + blocks = [] + for axis in kernel_meta.axis_info: + block_name = axis.split_name + sub_block_name = axis.tiling_name + block_size = axis.length + sub_block_size = block_size + blocks.append( + BlockInfo(block_name, sub_block_name, block_size, sub_block_size) + ) + + return blocks + + @classmethod + def get_key_from_dict(cls, kwargs: Dict[str, int]): + return tuple(sorted(kwargs.items())) + + def valid_tile_numel(self, tile_numel: int) -> bool: + return tile_numel <= self.max_tile_numel + + def calculate_tile_numel(self) -> int: + tile_numel = 1 + for axis in self.kernel_meta.axis_info: + if axis.is_tiling_axis: + tile_numel *= self.blocks[axis.index].sub_block_size + else: + # this axis's tiling size is the same as block size + tile_numel *= self.blocks[axis.index].block_size + + return tile_numel + + def add_to_configs(self, cand_sizes) -> None: + kwargs = {} + for axis in self.kernel_meta.axis_info: + if not (axis.is_split_axis or axis.is_tiling_axis): + continue + + block_info = self.blocks[axis.index] + if axis.is_split_axis: + kwargs[block_info.block_name] = cand_sizes[axis.index] + if axis.is_tiling_axis: + kwargs[block_info.sub_block_name] = next_power_of_2( + block_info.sub_block_size + ) + + tile_numel = 1 + for axis in self.kernel_meta.axis_info: + if not (axis.is_split_axis or axis.is_tiling_axis): + tile_numel *= self.blocks[axis.index].block_size + continue + + if axis.is_tiling_axis: + tile_numel 
*= kwargs.get(self.blocks[axis.index].sub_block_name, 1) + else: + tile_numel *= kwargs.get(self.blocks[axis.index].block_name, 1) + + key = self.get_key_from_dict(kwargs) + if self.valid_tile_numel(tile_numel) and key not in self.configs: + self.configs[key] = Config( + kwargs, num_warps=self.num_warps, num_stages=self.num_stages + ) + + def descend_one_axis(self, axis_idx: int, is_split=False) -> bool: + def calc_total_programs(): + grids = [] + for axis in self.kernel_meta.split_axis: + block_size = self.blocks[axis.index].block_size + programs = (axis.length + block_size - 1) // block_size + grids.append(programs) + + return functools.reduce(lambda x, y: x * y, grids) if grids else 1 + + reached_stop_numel = False + slow_descend_split = False + magic_descend_split = False + if not is_split and len(self.candidates_blk_sizes) == 0: + self.candidates_blk_sizes.append( + tuple([x.block_size for x in self.blocks]) + ) + + axis = self.kernel_meta.axis_info[axis_idx] + while True: + for cand_sizes in self.candidates_blk_sizes: + self.add_to_configs(cand_sizes) + + # tile numel reached threshold + tile_numel = self.calculate_tile_numel() + if tile_numel <= self.stop_numel: + self.add_to_configs([x.block_size for x in self.blocks]) + reached_stop_numel = True + break + + numel = ( + self.blocks[axis_idx].block_size + if is_split + else self.blocks[axis_idx].sub_block_size + ) + if numel == 1: + self.add_to_configs([x.block_size for x in self.blocks]) + break + + if is_split: + if self.persistent_reduction and axis.is_reduction: + reached_stop_numel = True + break + total_programs = calc_total_programs() + if total_programs > num_vector_core: + break + if total_programs > num_vector_core // 2 or self.dual_reduction: + if len(self.candidates_blk_sizes) > 2: + self.candidates_blk_sizes.pop(0) + self.candidates_blk_sizes.append( + tuple([x.block_size for x in self.blocks]) + ) + + if ( + not magic_descend_split + and (numel // 2) <= self.split_magic_nums[axis_idx] + ): + self.blocks[axis_idx].block_size = self.split_magic_nums[axis_idx] + self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size + magic_descend_split = True + continue + + self.blocks[axis_idx].block_size = numel // 2 + self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size + if calc_total_programs() > num_vector_core: + slow_descend_split = True + step = numel // 4 if numel // 4 > 1 else 1 + self.blocks[axis_idx].block_size = ( + numel // 2 if not slow_descend_split else numel - step + ) + self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size + else: + self.blocks[axis_idx].sub_block_size = next_power_of_2(numel // 2) + + return reached_stop_numel + + def descend_all_low_dims(self) -> None: + low_dim_numels = [self.blocks[x.index].sub_block_size for x in self.kernel_meta.low_dims_axis] + if not low_dim_numels: + return + + def descend_all_axis(min_numel): + for axis in self.kernel_meta.low_dims_axis: + if axis.is_reduction and self.persistent_reduction: + continue + numel = self.blocks[axis.index].sub_block_size + if numel == 1: + continue + if min_numel > 1 and abs(numel - min_numel) / min_numel < 0.2: + continue + self.blocks[axis.index].sub_block_size = next_power_of_2(numel // 2) + + if len(self.candidates_blk_sizes) == 0: + # means there is no split axis and tiling_not_low_dim axis + # so we need to init the candidates_blk_sizes + self.candidates_blk_sizes.append( + tuple([x.block_size for x in self.blocks]) + ) + count = 0 + tile_numel = self.calculate_tile_numel() + while 
tile_numel > self.stop_numel and count < 100: + count += 1 + tile_numel = self.calculate_tile_numel() + for cand_sizes in self.candidates_blk_sizes: + self.add_to_configs(cand_sizes) + min_numel = min(low_dim_numels) + descend_all_axis(min_numel) + new_tile_numel = self.calculate_tile_numel() + if tile_numel == new_tile_numel: + descend_all_axis(0) + + def descend_split_tiling(self): + + tiling_not_low_dims = [ + x + for x in self.kernel_meta.tiling_axis + if x not in self.kernel_meta.low_dims_axis + ] + + def descend_split_axis(): + for axis in self.kernel_meta.split_axis: + if self.descend_one_axis(axis.index, is_split=True): + return True + return self.calculate_tile_numel() <= self.stop_numel + + def descend_tiling_not_low_dims(): + for axis in tiling_not_low_dims: + if axis.is_reduction and self.persistent_reduction: + continue + if self.descend_one_axis(axis.index): + return True + return self.calculate_tile_numel() <= self.stop_numel + + def descend_low_dims(): + for axis in self.kernel_meta.tiling_axis: + if axis.is_reduction and self.persistent_reduction: + continue + if axis in tiling_not_low_dims: + continue + if self.descend_one_axis(axis.index): + return True + return self.calculate_tile_numel() <= self.stop_numel + + while True: + # descend split axis + if descend_split_axis(): + break + + if len(self.candidates_blk_sizes) > 0: + candi_blk = self.candidates_blk_sizes[0] + for i, blk_size in enumerate(candi_blk): + self.blocks[i].sub_block_size = blk_size + + # descend tiling but not low dims + if descend_tiling_not_low_dims(): + break + + # descend low dims + self.descend_all_low_dims() + break diff --git a/python/triton/runtime/tile_generator.py.none b/python/triton/runtime/tile_generator.py.none new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/runtime/utils.py b/python/triton/runtime/utils.py new file mode 100644 index 000000000..653fff403 --- /dev/null +++ b/python/triton/runtime/utils.py @@ -0,0 +1,69 @@ +import torch + +from .driver import driver + +# npu hardware params +target = driver.active.get_current_target() +device = driver.active.get_current_device() +prop = driver.active.utils.get_device_properties(device) + +num_cube_core = prop["num_aicore"] +num_vector_core = prop["num_aicore"] + +if "Ascend910B" in target.arch: + num_vector_core = num_cube_core * 2 + +# wrapper npu 32 bytes align, get and pass unalign info to triton meta +# then autotune choose tiling param and send them to bishengIR +byte_per_numel = { + torch.float32: 4, # torch.float32 or torch.float + torch.float64: 8, # torch.float64 or torch.double + torch.float16: 2, # torch.float16 or torch.half + torch.bfloat16: 2, # torch.bfloat16 + torch.int32: 4, # torch.int32 or torch.int + torch.int64: 8, # torch.int64 or torch.long + torch.int16: 2, # torch.int16 or torch.short + torch.int8: 1, # torch.int8 + torch.uint8: 1, # torch.uint8 + torch.bool: 1, # torch.bool + torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) + torch.complex64: 8, # torch.complex64 + torch.complex128: 16, # torch.complex128 +} + +valid_axis_names = { + "x", + "y", + "z", + "w", + "v", + "t", + "rx", + "ry", + "rz", + "rw", + "rv", + "rt", +} + + +def get_byte_per_numel(dtype: torch.dtype) -> int: + return 1 if dtype is None else byte_per_numel[dtype] + + +def is_valid_axis_name(name: str) -> bool: + return name in valid_axis_names + + +# move to an appropriate place, currently duplicated with triton.__init__.py +def next_power_of_2(n: int): + """Return the 
smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/python/triton/runtime/utils.py.none b/python/triton/runtime/utils.py.none new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/testing.py b/python/triton/testing.py index 71cb8ab1e..b929ef22c 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,6 +1,8 @@ import functools import os import subprocess +import multiprocessing +import os import sys from contextlib import contextmanager from typing import Any, Dict, List @@ -112,6 +114,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m assert return_mode in ["min", "max", "mean", "median", "all"] import torch + enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' + if torch.npu.is_available() and enable_bench_npu: + avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) + return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + di = runtime.driver.active.get_device_interface() fn() @@ -157,6 +164,99 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) +def collect_files(base_dir): + import pandas as pd + for root, dirs, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] + if not triton_rows.empty: + return triton_rows['Avg Time(us)'].values[0] + return float('inf') + return float('inf') + + +def collect_single(base_dir: str, key: str = None) -> float: + if not os.path.exists(base_dir): + return float('inf') + + import pandas as pd + for root, _, files in os.walk(base_dir): + for file in files: + if file != 'op_statistic.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + if key is not None: + key_rows = df[df['OP Type'].str.startswith(key, na=False)] + if not key_rows.empty: + return key_rows['Avg Time(us)'].values[0] + return float('inf') + else: + # default: read the first row except header + return df.loc[0, 'Avg Time(us)'] + + return float('inf') + + +def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): + import torch + import torch_npu + from datetime import datetime, timezone + + # warmup kernel + fn() + torch.npu.synchronize() + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + skip_first = 1 + wait = 0 + repeat = 1 + total = skip_first + (wait + warmup + active) * repeat + + if prof_dir is not None: + torch_path = prof_dir + else: + process = multiprocessing.current_process() + pid = process.pid + process_name = process.name + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") + torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + 
schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) as prof: + for _ in range(total): + fn() + prof.step() + torch.npu.synchronize() + + time = collect_single(torch_path) + + if not keep_res: + import shutil + if os.path.exists(torch_path): + shutil.rmtree(torch_path) + + return time def assert_close(x, y, atol=None, rtol=None, err_msg=''): """ @@ -331,7 +431,6 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b ax.legend() ax.set_xlabel(bench.xlabel or first_x) ax.set_ylabel(bench.ylabel) - # ax.set_title(bench.plot_name) ax.set_xscale("log" if bench.x_log else "linear") ax.set_yscale("log" if bench.y_log else "linear") if show_plots: @@ -509,3 +608,206 @@ def get_max_simd_tflops(dtype, clock_rate, device=None): raise RuntimeError("dtype not supported") tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops + +# Patch the triton language API here because triton's __init__.py +# import testing in the last stages. +from .language.tensor_descriptor import ( + tensor_descriptor, + tensor_descriptor_type, +) + +from .language.core_ext import ( + dot, + cast, + gather, + get_element, + insert_slice, + extract_slice, + trans, + __add__, + __radd__, + __sub__, + __rsub__, + __mul__, + __rmul__, + __lshift__, + __rshift__, + parallel, + compile_hint, + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + multibuffer, + sync_block_all, + sync_block_set, + sync_block_wait, + dtype_to_ir, + sort +) +from .language.standard_ext import flip, sigmoid, softmax, isfinited, finitef, rint, atan2 +from .language.math_ext import ( + umulhi, + exp, + exp2, + log, + log2, + cos, + sin, + sqrt, + sqrt_rn, + rsqrt, + div_rn, + erf, + tanh, + floor, + ceil, + _check_dtype, + fma, +) +from .language.semantic_ext import ( + arange, + floordiv, + atom_red_typechecking_impl, + atomic_cas, + atomic_max, + atomic_min, + _load_legacy, + maximum, + minimum, + mod, + invert, + logical_and, + logical_or, + not_, + and_, + or_, + xor_, + minus, + dot_scaled, +) +from . 
import language + +language.cast = cast +language.dot = dot +language.flip = flip +language.sigmoid = sigmoid +language.softmax = softmax +language.gather = gather +language.insert_slice = insert_slice +language.extract_slice = extract_slice +language.get_element = get_element +language.tensor.__add__ = __add__ +language.tensor.__radd__ = __radd__ +language.tensor.__sub__ = __sub__ +language.tensor.__rsub__ = __rsub__ +language.tensor.__mul__ = __mul__ +language.tensor.__rmul__ = __rmul__ +language.tensor.__lshift__ = __lshift__ +language.tensor.__rshift__ = __rshift__ +language.trans = trans +language.parallel = parallel +language.compile_hint = compile_hint +language.sort = sort +language.multibuffer = multibuffer +language.sync_block_all = sync_block_all +language.sync_block_set = sync_block_set +language.sync_block_wait = sync_block_wait +language.make_tensor_descriptor = make_tensor_descriptor +language.tensor_descriptor = tensor_descriptor +language.tensor_descriptor_type = tensor_descriptor_type +language.load_tensor_descriptor = load_tensor_descriptor +language.store_tensor_descriptor = store_tensor_descriptor + +language.semantic.arange = arange +language.semantic.floordiv = floordiv +language.semantic.atom_red_typechecking_impl = atom_red_typechecking_impl +language.semantic.atomic_cas = atomic_cas +language.semantic.atomic_max = atomic_max +language.semantic.atomic_min = atomic_min +language.semantic._load_legacy = _load_legacy +language.semantic.maximum = maximum +language.semantic.minimum = minimum +language.semantic.invert = invert +language.semantic.logical_and = logical_and +language.semantic.logical_or = logical_or +language.semantic.mod = mod +language.semantic.not_ = not_ +language.semantic.and_ = and_ +language.semantic.or_ = or_ +language.semantic.xor_ = xor_ +language.semantic.minus = minus +language.semantic.dot_scaled = dot_scaled + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.core.dtype.to_ir = dtype_to_ir +language.fma = fma +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +language.math._check_dtype = _check_dtype +language.math.fma = fma +language.math.isnan = language.extra.ascend.libdevice.isnan +language.math.isinf = language.extra.ascend.libdevice.isinf +language.math.reciprocal = language.extra.ascend.libdevice.reciprocal +language.math.log1p = language.extra.ascend.libdevice.log1p +language.math.relu = language.extra.ascend.libdevice.relu +language.math.tan = language.extra.ascend.libdevice.tan +language.math.atan = language.extra.ascend.libdevice.atan +language.math.tanh = language.extra.ascend.libdevice.tanh +language.math.ilogb = language.extra.ascend.libdevice.ilogb +language.math.ldexp = language.extra.ascend.libdevice.ldexp +language.math.pow = language.extra.ascend.libdevice.pow +language.math.flip = language.extra.ascend.libdevice.flip +language.math.atan2 = 
language.extra.ascend.libdevice.atan2 +language.math.div_rz = language.extra.ascend.libdevice.div_rz +language.math.fmod = language.extra.ascend.libdevice.fmod +language.math.trunc = language.extra.ascend.libdevice.trunc +language.math.round = language.extra.ascend.libdevice.round +language.math.finitef = finitef +language.math.isfinited = isfinited +language.math.rint = rint +language.math.atan2 = atan2 +language.extra.ascend.libdevice.umulhi = language.math.umulhi +language.extra.ascend.libdevice.exp = language.math.exp +language.extra.ascend.libdevice.exp2 = language.math.exp2 +language.extra.ascend.libdevice.log = language.math.log +language.extra.ascend.libdevice.log2 = language.math.log2 +language.extra.ascend.libdevice.cos = language.math.cos +language.extra.ascend.libdevice.sin = language.math.sin +language.extra.ascend.libdevice.sqrt = language.math.sqrt +language.extra.ascend.libdevice.sqrt_rn = language.math.sqrt_rn +language.extra.ascend.libdevice.rsqrt = language.math.rsqrt +language.extra.ascend.libdevice.div_rn = language.math.div_rn +language.extra.ascend.libdevice.erf = language.math.erf +language.extra.ascend.libdevice.tanh = language.math.tanh +language.extra.ascend.libdevice.floor = language.math.floor +language.extra.ascend.libdevice.ceil = language.math.ceil +language.extra.ascend.libdevice.fdiv = language.math.fdiv +language.extra.ascend.libdevice.fma = language.math.fma +language.extra.ascend.libdevice.abs = language.math.abs diff --git a/python/triton/testing.py.std b/python/triton/testing.py.std new file mode 100644 index 000000000..71cb8ab1e --- /dev/null +++ b/python/triton/testing.py.std @@ -0,0 +1,511 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl +from . import runtime + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def _summarize_statistics(times, quantiles, return_mode): + import torch + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(torch, return_mode)(times).item() + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". 
+ :type return_mode: str + """ + import torch + assert return_mode in ["min", "max", "mean", "median", "all"] + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + import torch + + di = runtime.driver.active.get_device_interface() + + fn() + di.synchronize() + + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. 
So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + return _summarize_statistics(times, quantiles, return_mode) + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. 
+ :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. 
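+ # (When x_vals supplies tuples, only this first named x value is used for the x axis; the remaining x columns are still recorded in the DataFrame above.)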
+ first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("<html><body>\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"<image src=\"{bench.plot_name}.png\"/>\n") + if save_path: + html.write("</body></html>\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. 
+ :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not 
device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt new file mode 100644 index 000000000..8b77a38ed --- /dev/null +++ b/third_party/ascend/CMakeLists.txt @@ -0,0 +1,7 @@ +#add_subdirectory(triton-adapter triton-adapter) +#add_subdirectory(test) + +add_triton_plugin(TritonAscend ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp) +target_include_directories(TritonAscend PRIVATE ${CMAKE_SOURCE_DIR}/third_party/flir/include) + +add_triton_library(Registrar Registrar.cc) diff --git a/third_party/ascend/README.md b/third_party/ascend/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/ascend/Registrar.cc b/third_party/ascend/Registrar.cc new file mode 100644 index 000000000..06eb19729 --- /dev/null +++ b/third_party/ascend/Registrar.cc @@ -0,0 +1,10 @@ +#include "flagtree/Common/UnifiedHardware.h" + +class AscendUnifiedHardware : public mlir::flagtree::UnifiedHardware { +public: +}; + +std::unique_ptr<mlir::flagtree::UnifiedHardware> +mlir::flagtree::createUnifiedHardwareManager() { + return std::make_unique<AscendUnifiedHardware>(); +} diff --git a/third_party/ascend/backend/__init__.py b/third_party/ascend/backend/__init__.py new file mode 100644 index 000000000..0eec99724 --- /dev/null +++ b/third_party/ascend/backend/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
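The compiler.py added below wires the Ascend pipeline through AscendBackend.add_stages: ttir -> ttadapter -> npubin for the NPU target, and ttir -> ttadapter -> llir -> cpuasm for the CPU target, with the final stage name matching the backend's binary_ext. As a rough orientation only, here is a minimal sketch of how such a stage dictionary is typically consumed in order; run_stages is a hypothetical helper used for illustration, not part of this patch:

def run_stages(stages, src, metadata):
    # Each registered stage maps the previous stage's output to the next
    # representation; the last result is the artifact stored under binary_ext.
    for _name, stage_fn in stages.items():
        src = stage_fn(src, metadata)
    return src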
diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py new file mode 100644 index 000000000..865bee249 --- /dev/null +++ b/third_party/ascend/backend/compiler.py @@ -0,0 +1,505 @@ +import ctypes +import functools +import hashlib +import os +import re +import subprocess +import tempfile +from dataclasses import dataclass +from pathlib import Path +from types import ModuleType +from typing import Any, Dict, Optional, Tuple, Union + +from triton._C.libtriton import ir, passes, ascend +from triton.backends.ascend.utils import ( + _check_bishengir_api_change, + _check_bishengir_is_regbased, + _enable_unpublished_feature, + _get_kernel_target, + _get_llvm_path, + _get_mlir_path, + _get_npucompiler_path, + _get_triton_adapter_opt_path, + _is_ascend_sanitizer_enabled, + _is_debug_line_info_disabled, + _is_auto_map_parallel_blocks_enabled, + downgrade_llir, +) +from triton.backends.ascend.driver import ( + NPUUtils +) +from triton.backends.compiler import ( + AttrsDescriptor, + BaseBackend, + GPUTarget, + register_descriptor, +) +from triton.runtime import driver +from triton.runtime.cache import get_dump_manager + + +# TODO: materialize the concrete min shape +def min_dot_size(target: GPUTarget): + return lambda lhsType, rhsType: (1, 1, 1) + + +def make_ttir(mod, metadata, opt): + if "hash" not in metadata: + metadata["hash"] = hashlib.sha256(f"{mod}-{metadata}".encode()).hexdigest() + # the same optimize pass for triton-ir as all other backends + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + if opt.debug: + dump_manager = get_dump_manager(metadata["hash"]) + print(f"Dumping intermediate results to {dump_manager.cache_dir}") + dump_manager.put(str(mod), "kernel.ttir.mlir", binary=False) + + return mod + + +def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + # Add pass here. + ascend.passes.convert.add_triton_to_linalg_pipeline(pm) + pm.run(mod) + return str(mod) + ''' + with open('/home/zhengyang/FlagTree/triton-op-ir/ops/01_vector_add/cache_ascend/add_kernel.ttadapter', 'r', encoding='utf-8') as f: + content = f.read() + return content + ''' + + +def linalg_to_llir(linalg: str, metadata, opt): + with tempfile.TemporaryDirectory() as tmpdir: + ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + llmlir_path = os.path.join(tmpdir, "kernel.llir.mlir") + llir_path = os.path.join(tmpdir, "kernel.ll") + Path(ttadapter_path).write_text(linalg) + mlir_opt_path = _get_mlir_path("bin", "mlir-opt") + # TritonAdapter-MLIR to LLVM-MLIR + subprocess.check_call( + [ + mlir_opt_path, + ttadapter_path, + "--convert-linalg-to-affine-loops", + "--eliminate-empty-tensors", + "--empty-tensor-to-alloc-tensor", + "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", + "--convert-linalg-to-loops", + "--convert-scf-to-cf", + "--convert-cf-to-llvm", + "--convert-arith-to-llvm", + "--convert-math-to-llvm", + "--convert-complex-to-llvm", + "--convert-vector-to-llvm", + "--convert-index-to-llvm", + "--memref-expand", + "--expand-strided-metadata", + "--finalize-memref-to-llvm", + "--convert-func-to-llvm", + # Lowering memrefs creates more affine.apply ops. 
+ # Lowering these affine ops again creates further arith ops, + # so we have to run these two passes again here. + "--lower-affine", + "--convert-arith-to-llvm", + # Remove all unrealized casts created + "--reconcile-unrealized-casts", + "-o", + llmlir_path, + ] + ) + if opt.debug: + dump_manager = get_dump_manager(metadata["hash"]) + dump_manager.put( + Path(llmlir_path).read_text(), "kernel.llir.mlir", binary=False + ) + + # LLVM-MLIR to LLVM-IR + mlir_translate_path = _get_mlir_path("bin", "mlir-translate") + subprocess.check_call( + [mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path] + ) + if opt.debug: + dump_manager = get_dump_manager(metadata["hash"]) + dump_manager.put(Path(llir_path).read_text(), "kernel.ll", binary=False) + + return Path(llir_path).read_text() + + +def llir_to_cpuasm(llir: str, metadata, opt): + # add metadata at final stage + # Note: Compiled Kernel requires to estimate size of shared memory to occupy + # Currently, CPU backend requires no limit on shared memory size + metadata["shared"] = 1 + # We can get a function name (C naming) from + # LLVM-IR by getting the first "define void @". + fn_name = llir.split("define void @")[1].split("(")[0].strip() + metadata["name"] = fn_name + " cpu" + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ll") + linked_path = os.path.join(tmpdir, "kernel_linked.ll") + dst_path = os.path.join(tmpdir, "kernel.s") + + llir = downgrade_llir(llir) + if opt.debug: + dump_manager = get_dump_manager(metadata["hash"]) + dump_manager.put(llir, "kernel_downgrade.ll", binary=False) + + Path(src_path).write_text(llir) + + linker_path = _get_llvm_path("bin", "llvm-link") + libclc_path = _get_llvm_path("lib", "clc", "libspirv-aarch64--.bc") + subprocess.check_call( + [ + linker_path, + src_path, + libclc_path, + "--only-needed", + "-S", + "-o", + linked_path, + ] + ) + if opt.debug: + dump_manager = get_dump_manager(metadata["hash"]) + dump_manager.put( + Path(linked_path).read_text(), "kernel_linked.ll", binary=False + ) + + llc_path = _get_llvm_path("bin", "llc") + subprocess.check_call([llc_path, linked_path, "-o", dst_path]) + if opt.debug: + dump_manager = get_dump_manager(metadata["hash"]) + dump_manager.put(Path(dst_path).read_text(), "kernel.s", binary=False) + + # Actually it's text-format assembly. Use read_text(). + return Path(dst_path).read_text() + + +def __get_metadata_attr_by_callback(lib, postfix: str, metadata, meta_key: str): + func_symbol = metadata["kernel_name"] + postfix + if hasattr(lib, func_symbol): + callback_func = getattr(lib, func_symbol) + callback_func.restype = ctypes.c_int64 + callback_func.argtypes = [] + metadata[meta_key] = callback_func() + + +def _parse_linalg_metadata(linalg: str, metadata: dict): + """ + Parse Linalg IR to extract metadata required for NPU compilation. + Extracts and updates the following fields in metadata: + - mix_mode + - kernel_name + - tensor_kinds + - shared (currently hardcoded) + - name (combined kernel_name and mix_mode) + + Additionally, removes the mix_mode attribute from the IR. + """ + # --- Regular expressions and examples --- + + # Example: mix_mode = "aiv" -> aiv + MIX_MODE_REGEX = r'mix_mode\s*=\s*"([^"]+)"' + + # Example: func.func @gather_sorted_kernel(%arg0: ...) 
-> gather_sorted_kernel + KERNEL_NAME_REGEX = r"func\.func\s+@(\w+)" + + # Example: %arg1: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32} -> ('1', '0') + TENSOR_KIND_REGEX = r'%arg(\d+):[^,)]*?\{[^}]*?tt\.tensor_kind\s*=\s*([^:\s}]+)\s*:[^}]*?\}' + + # Example removal: ', mix_mode = "aiv"' → '' + REMOVE_MIX_MODE_REGEX = r', mix_mode\s*=\s*"[^"]*"' + + # Note: Compiled Kernel requires to estimate size of shared memory to occupy + # Currently, NPU backend does not limit on shared memory + metadata["shared"] = 1 + # the mix mode is also encoded into metadata['name'] for runtime to distinguish + metadata["mix_mode"] = re.search(MIX_MODE_REGEX, linalg).group(1) + metadata["kernel_name"] = re.search(KERNEL_NAME_REGEX, linalg).group(1) + # Use while space to split kernel_name and mix_mode. + # Check the function load_binary in npu_driver.py. + metadata["name"] = metadata["kernel_name"] + " " + metadata["mix_mode"] + # Parse all tensor kinds from arguments + metadata["tensor_kinds"] = [int(kind) for _, kind in re.findall(TENSOR_KIND_REGEX, linalg)] + # remove the mix_mode attribute + linalg = re.sub(REMOVE_MIX_MODE_REGEX, "", linalg) + return linalg, metadata + + +def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): + linalg, metadata = _parse_linalg_metadata(linalg, metadata) + with tempfile.TemporaryDirectory() as tmpdir: + ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir") + Path(ttadapter_path).write_text(linalg) + bin_file = os.path.join(tmpdir, "kernel") + if _check_bishengir_api_change(): + bin_file_with_ext = "kernel.o" + else: + bin_file_with_ext = "kernel_reloc.o" + if _check_bishengir_is_regbased(): + bishengir_hivm_opt = "--reg-based=true" + else: + bishengir_hivm_opt = "--enable-hivm-compile=true" + bin_path = os.path.join(tmpdir, bin_file_with_ext) + callback_path = os.path.join(tmpdir, "libkernel.so") + _compile_option_list = [] + if _enable_unpublished_feature(): + _compile_option_list += [ + f"--target={NPUUtils().get_arch()}", + ] + multibuffer = metadata["multibuffer"] + if multibuffer is not None: + _compile_option_list += [ + f"--enable-auto-multi-buffer={multibuffer}", + ] + if _is_ascend_sanitizer_enabled(): + _compile_option_list += ["--enable-sanitizer=true"] + if not _is_debug_line_info_disabled(): + _compile_option_list += ["--enable-debug-info=true"] + + enable_hivm_auto_cv_balance = metadata["enable_hivm_auto_cv_balance"] + if enable_hivm_auto_cv_balance is not None: + _compile_option_list += \ + [f"--enable-hivm-auto-cv-balance={enable_hivm_auto_cv_balance}"] + + unit_flag = metadata["unit_flag"] + if unit_flag is not None: + _compile_option_list += \ + [f"--enable-hivm-unit-flag-sync={unit_flag}"] + + inject_barrier_all = metadata["inject_barrier_all"] + if inject_barrier_all is not None: + _compile_option_list += \ + [f"--enable-hivm-inject-barrier-all-sync={inject_barrier_all}"] + + limit_auto_multi_buffer_only_for_local_buffer = metadata["limit_auto_multi_buffer_only_for_local_buffer"] + if limit_auto_multi_buffer_only_for_local_buffer is not None: + _compile_option_list += \ + [f"--limit-auto-multi-buffer-only-for-local-buffer={limit_auto_multi_buffer_only_for_local_buffer}"] + + set_workspace_multibuffer = metadata["set_workspace_multibuffer"] + if set_workspace_multibuffer is not None: + _compile_option_list += \ + [f"--set-workspace-multibuffer={set_workspace_multibuffer}"] + + tile_mix_vector_loop = metadata["tile_mix_vector_loop"] + if tile_mix_vector_loop is not None: + _compile_option_list += \ + 
[f"--tile-mix-vector-loop={tile_mix_vector_loop}"] + + tile_mix_cube_loop = metadata["tile_mix_cube_loop"] + if tile_mix_cube_loop is not None: + _compile_option_list += \ + [f"--tile-mix-cube-loop={tile_mix_cube_loop}"] + + auto_multi_buffer = metadata["limit_auto_multi_buffer_of_local_buffer"] + if auto_multi_buffer is not None: + _compile_option_list += \ + [f"--limit-auto-multi-buffer-of-local-buffer={auto_multi_buffer}"] + + if _is_auto_map_parallel_blocks_enabled(): + _compile_option_list += ["--enable-auto-blockify-loop"] + npu_compiler_path = _get_npucompiler_path() + if npu_compiler_path.endswith("bishengir-compile"): + _compile_option_list += [ + "--enable-hfusion-compile=true", + bishengir_hivm_opt, + "--enable-triton-kernel-compile=true", + ] + cmd_list = ( + [npu_compiler_path, ttadapter_path] + + _compile_option_list + + ["-o", bin_file] + ) + ret = subprocess.run(cmd_list, capture_output=True, check=True) + if Path(callback_path).is_file(): + lib = ctypes.CDLL(callback_path) + __get_metadata_attr_by_callback(lib, "_infer_workspace_shape_function", metadata, "workspace_size") + __get_metadata_attr_by_callback(lib, "_infer_sync_block_lock_num_function", metadata, "lock_num") + __get_metadata_attr_by_callback(lib, "_infer_sync_block_lock_init_function", metadata, "lock_init_val") + + return Path(bin_path).read_bytes() + + +@dataclass(frozen=True) +class NPUOptions: + debug: bool = False + sanitize_overflow: bool = True + llvm_version: int = 15 + kernel_name: str = "triton_" + + cluster_dims: tuple = (1, 1, 1) + num_warps: int = -1 + num_ctas: int = -1 + num_stages: int = 2 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + + enable_warp_specialization: bool = False + enable_nd2nz_on_vector: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + + multibuffer: bool = None + enable_hivm_auto_cv_balance: bool = None + unit_flag: bool = None + inject_barrier_all: bool = None + limit_auto_multi_buffer_only_for_local_buffer: bool = None + limit_auto_multi_buffer_of_local_buffer: str = None + set_workspace_multibuffer: int = None + tile_mix_vector_loop: int = None + tile_mix_cube_loop: int = None + + stream: int = None + + def hash(self): + key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +@dataclass(frozen=True) +class CPUOptions: + debug: bool = False + llvm_version: int = 15 + kernel_name: str = "triton_" + + cluster_dims: tuple = (1, 1, 1) + num_warps: int = -1 + num_ctas: int = -1 + num_stages: int = -1 + + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + + def hash(self): + key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +@register_descriptor +class AscendAttrsDescriptor(AttrsDescriptor): + + # For now we collect shapes of tensor at runtime. + # We comment out the following func but keep it for future reference. 
+ def _add_backend_properties(self, params=None, values=None): + pass + + +class AscendBackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == "cpu" or target.backend == "npu" + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + if target.backend == "cpu": + self.binary_ext = "cpuasm" + elif target.backend == "npu": + self.binary_ext = "npubin" + + def parse_options(self, opts) -> Any: + # TODO: get available targets when building options? + if self.target.backend == "npu": + args = { + k: opts[k] + for k in NPUOptions.__dataclass_fields__.keys() + if k in opts + } + options = NPUOptions(**args) + else: + args = { + k: opts[k] + for k in CPUOptions.__dataclass_fields__.keys() + if k in opts + } + options = CPUOptions(**args) + return options + + def pack_metadata(self, metadata): + # collect necessary metadata to launch kernels + # TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 could set unique name. + # Get this name as the kernel_name to CANN runtime. + # kernel_name is unique to Ascend backend and should not be public. + # CANN runtime limits the length of kernel name <= 50. + # Considering '\n' is appended, thus the real kernel name <= 49. + KERNEL_NAME_MAX_LEN = 49 + kernel_name_orig, mix_mode = metadata.name.split() + if len(kernel_name_orig) > KERNEL_NAME_MAX_LEN: + kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:] + else: + kernel_name = kernel_name_orig + return { + "kernel_name": kernel_name, + "hash": metadata.hash, + "debug": metadata.debug, + "tensor_kinds": metadata.tensor_kinds, + } + + def get_codegen_implementation(self): + # Note: a dict of functions is required to generate vendor-specific code piecies + # e.g. convert custom types like fp8e4b15 + codegen_fns = {"min_dot_size": min_dot_size(self.target)} + return codegen_fns + + def load_dialects(self, ctx): + pass + + def get_attrs_descriptor(self, params, args): + return AscendAttrsDescriptor(params, args) + + def add_stages(self, stages, options): + if self.target.backend == "npu": + stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) + stages["ttadapter"] = lambda src, metadata: ttir_to_linalg( + src, metadata, options, named_ops=True + ) + stages["npubin"] = ( + lambda src, metadata: linalg_to_bin_enable_npu_compile( + src, metadata, options + ) + ) + else: + stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options) + stages["ttadapter"] = lambda src, metadata: ttir_to_linalg( + src, metadata, options + ) + stages["llir"] = lambda src, metadata: linalg_to_llir( + src, metadata, options + ) + stages["cpuasm"] = lambda src, metadata: llir_to_cpuasm( + src, metadata, options + ) + + @functools.lru_cache() + def hash(self): + # TODO fetch compiler version + version_key = self.target + return str(version_key) + + def get_module_map(self) -> Dict[str, ModuleType]: + return {} diff --git a/third_party/ascend/backend/cpu_driver.py b/third_party/ascend/backend/cpu_driver.py new file mode 100644 index 000000000..70b974e25 --- /dev/null +++ b/third_party/ascend/backend/cpu_driver.py @@ -0,0 +1,173 @@ +from triton.runtime.cache import get_cache_manager, get_dump_manager +from pathlib import Path +import tempfile +import os +import sysconfig +import subprocess +import importlib +from triton.backends.ascend.utils import _get_llvm_path + +# TODO: temporarily fake CPUUtils class +class CPUUtils(object): + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(CPUUtils, cls).__new__(cls) + return 
cls.instance + + def __init__(self): + pass + + def get_device_properties(self, device): + # temperoarily added properties to avoid triton-compiler complain + # fetch available memory at runtime + return {"max_shared_mem": 1} + + def load_binary(self, name, kernel, shared, device): + # TODO (temperoarily fake function) load a binary from binary object to device + # return value are: (mod, funcptr/handle, n_regs, n_spills) + return None, kernel, 0, 0 + +class CPULauncher(object): + def __init__(self, src, metadata): + kernel_name = metadata.name.split()[0] + signature = src.signature + constants = src.constants + launcher_src = generate_cpu_wrapper_src(constants, signature, kernel_name) + self.launch = compile_module(launcher_src) + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class CPUDriver: + def __init__(self): + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + super().__init__() + + def get_current_target(self): + # TODO: do we rely on CPU arch? + return ("cpu", "arm-64") + + def get_current_device(self): + """ + Get current device + """ + # TODO: dummy device-getter for cpu backend + return 0 + + def set_current_device(self, device): + """ + Set current device as the given device + """ + # TODO: dummy device-setter for cpu backend + return + + def get_current_stream(self, device): + """ + Get stream for current device + """ + # TODO: dummy stream api for cpu backend. + return 0 + +# the template is from triton-adapter HEAD. Wrapping the generated kernel assembly into a python module +def generate_cpu_wrapper_src(constants, signature, kernel_name): + def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def _extracted_ty(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + def _generate_launcher(constants, signature, kernel_name): + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + format = "iiiOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) + # to be filled + return f""" + """ + launcher_src = _generate_launcher(constants, signature, kernel_name) + return launcher_src + +def compile_module(launcher_src): + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
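+ # (Falling back to 'posix_prefix' lets sysconfig.get_paths() below resolve the standard CPython include directory.)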
+ if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + def launch(gridX, gridY, gridZ, stream, cu_function, + packed_metadata, launch_metadata, + launch_enter_hook, launch_exit_hook, + *args): + kernel_name = packed_metadata["kernel_name"] + cache = get_cache_manager(packed_metadata["hash"]) + filename = f"{kernel_name}_cpu_launcher.so" + cache_path = cache.get_file(filename) + if cache_path is None: + asm_src = cu_function + with tempfile.TemporaryDirectory() as tmpdir: + asm_src_path = os.path.join(tmpdir, "kernel.s") + launcher_src_path = os.path.join(tmpdir, "main.cxx") + if packed_metadata["debug"]: + dump_manager = get_dump_manager(packed_metadata["hash"]) + dump_manager.put(launcher_src, "kernel_cpu_launcher.cxx", binary=False) + so_path = os.path.join(tmpdir, "kernel.so") + Path(asm_src_path).write_bytes(asm_src) + Path(launcher_src_path).write_text(launcher_src) + # Compile it together. + subprocess.check_call([_get_llvm_path("bin", "clang++"), launcher_src_path, asm_src_path, f"-I{py_include_dir}", f"-I{Path(__file__).resolve().parent}", "-shared", "-fPIC", "-o", so_path]) + + with open(so_path, "rb") as f: + cache_path = cache.put(f.read(), filename, binary=True) + + # Load and launch the compiled kernel. + spec = importlib.util.spec_from_file_location("__triton_adapter_ref_cpu_kernel_launcher", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.launch(gridX, gridY, gridZ, launch_enter_hook, launch_exit_hook, packed_metadata, *args) + + return launch \ No newline at end of file diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py new file mode 100644 index 000000000..8ea0873a8 --- /dev/null +++ b/third_party/ascend/backend/driver.py @@ -0,0 +1,728 @@ +from pathlib import Path +import tempfile +import os +import os.path +import re +import subprocess +import sysconfig +from typing import Optional +import functools +import hashlib +from triton.runtime.cache import get_cache_manager, get_dump_manager +from triton.backends.driver import DriverBase +from triton.backends.compiler import GPUTarget +from triton.backends.ascend.utils import ( + _build_npu_ext, + _check_cxx11_abi, + convert_sigtype_to_int, + _is_auto_map_parallel_blocks_enabled, +) + +class NPUUtils(object): + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(NPUUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + dirname = os.path.dirname(os.path.realpath(__file__)) + src = Path(os.path.join(dirname, "npu_utils.cpp")).read_text() + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + fname = "npu_utils.so" + cache_path = cache.get_file(fname) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "npu_utils.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build_npu_ext("npu_utils", src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), fname, binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location("npu_utils", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.npu_utils_mod = mod + + def load_binary(self, name, kernel, shared, device): + fnname, mix_mode = name.split() + return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode) + + @functools.lru_cache() + def get_device_properties(self, device): 
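+ # (num_aicore and num_vectorcore reported here are reused by the launcher when auto-mapping parallel blocks.)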
+ # temperoarily added "max_shared_mem" properties to avoid triton-compiler complain + # fetch available memory at runtime + num_aic = self.get_aicore_num() + num_aiv = num_aic * 2 + return {"max_shared_mem": 1, "num_aicore": num_aic, "num_vectorcore": num_aiv} + + @functools.lru_cache() + def get_arch(self): + # temporarily return empty arch descriptor + return self.npu_utils_mod.get_arch() + + @functools.lru_cache() + def get_aicore_num(self): + # temporarily return empty arch descriptor + return self.npu_utils_mod.get_aicore_num() + + @functools.lru_cache() + def get_aivector_core_num(self): + return self.get_device_properties("npu")["num_vectorcore"] + + +class NPULauncher(object): + def __init__(self, src, metadata): + debug_mode = metadata.debug + workspace_size = int(metadata.workspace_size) \ + if hasattr(metadata, 'workspace_size') else -1 + lock_init_value = int(metadata.lock_init_value) \ + if hasattr(metadata, 'lock_init_value') else 0 + lock_num = int(metadata.lock_num) \ + if hasattr(metadata, 'lock_num') else -1 + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + mix_mode = metadata.mix_mode + wrapper_src = generate_npu_wrapper_src(constants, signature, \ + workspace_size, mix_mode, \ + lock_num, lock_init_value) + so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) + # initialize launcher + import importlib.util + spec = importlib.util.spec_from_file_location("__triton_launcher", so_launcher_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.launch = getattr(mod, "launch") + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class NPUDriver(DriverBase): + def __init__(self): + self.utils = NPUUtils() + self.launcher_cls = NPULauncher + super().__init__() + + @classmethod + def is_active(cls): + def test_npucompiler(): + from triton.backends.ascend.utils import _get_bisheng_path + npucompiler = _get_bisheng_path() + targets = subprocess.check_output([npucompiler, "-print-targets"]).decode().strip().split() + return "hiipu64" in targets + try: + return test_npucompiler() + except Exception as e_npucompiler: + import warnings + red = "\x1b[31;20m" + reset = "\x1b[0m" + warnings.warn(red + str(e_npucompiler) + reset) + return False + + def get_current_target(self): + backend = "npu" + arch = self.utils.get_arch() + warp_size = 0 + return GPUTarget(backend, arch, warp_size) + + def get_current_device(self): + """ + Get current device + """ + import torch + import torch_npu + return torch.npu.current_device() + + def set_current_device(self, device): + """ + Set current device as the given device + """ + import torch + import torch_npu + return torch.npu.set_device(device) + + def get_current_stream(self, device: Optional[int] = None) -> int: + """ + Get stream for current device + """ + # According to torch_npu, the content of a torch.npu.Stream is essentilly an rtStream_t + # TODO: use CANN API instead of torchnpu + import torch + import torch_npu + if device is None: + device = self.get_current_device() + return torch.npu.current_stream(device).npu_stream + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_device_interface(self): + import torch + return torch.npu + + def get_empty_cache_for_benchmark(self): + 
import torch + cache_size = 192 * 1024 * 1024 + return torch.empty(cache_size // 4, dtype=torch.int, device='npu') + + +def make_npu_launcher_stub(src, debug=False): + """ + Generate the launcher stub to launch the kernel + """ + # try to get cached file + so_cache_key = hashlib.sha256(src.encode("utf-8")).hexdigest() + so_cache_manager = get_cache_manager(so_cache_key) + # append the cxx11_abi value to the launcher name to avoid + # linking to a launcher with wrong cxx11_abi. + use_cxx11_abi = _check_cxx11_abi() + name = f"launcher_cxx11abi{use_cxx11_abi}" + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so_name = f"{name}{suffix}" + + if debug: + dump_manager = get_dump_manager(so_cache_key) + print(f"Dumping {name}.cxx to {dump_manager.cache_dir}") + dump_manager.put(src, f"{name}.cxx", binary = False) + + cache_path = so_cache_manager.get_file(so_name) + if cache_path is not None: + return cache_path + + with tempfile.TemporaryDirectory() as tmpdir: + if debug: + so_cache_manager.put(src, f"{name}.cxx", binary=False) + src_path = os.path.join(tmpdir, f"{name}.cxx") + with open(src_path, "w") as f: + f.write(src) + enable_taskqueue = os.getenv("TRITON_ENABLE_TASKQUEUE", 'true').lower() in ('true', '1') + if (enable_taskqueue): + kernel_launcher_type = "torch" + else: + kernel_launcher_type = None + so = _build_npu_ext(name, src_path, tmpdir, kernel_launcher=kernel_launcher_type) + if debug: + with open(so, "rb") as f: + return dump_manager.put(f.read(), so_name, binary=True) + with open(so, "rb") as f: + return so_cache_manager.put(f.read(), so_name, binary=True) + + +def extract_device_print_code_from_cann(): + from triton.backends.ascend.utils import _get_bisheng_path + ccec_compiler_bin_folder, _ = os.path.split(os.path.realpath(_get_bisheng_path())) + ccec_compiler_folder, _ = os.path.split(ccec_compiler_bin_folder) + clang_version = os.listdir(os.path.join(ccec_compiler_folder, "lib/clang/"))[0] + ccelib_path = os.path.join(ccec_compiler_folder, f"lib/clang/{clang_version}/include/ccelib") + + def read_header(header_path): + with open(os.path.join(ccelib_path, header_path), 'r') as f: + code = f.read() + + # remove all #include "..." 
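+ # (Safe to strip: the required ccelib headers are concatenated below in a fixed dependency order.)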
+ lines = code.splitlines() + purged_lines = [] + for line in lines: + normalized_line = ' '.join(line.split()) + if not normalized_line.startswith('#include "'): + purged_lines.append(line) + code = '\n'.join(purged_lines) + + # remove [aicore] functions + aicore_positions = [] + for m in re.finditer('\[aicore\]', code): + aicore_positions.append(m.start()) + + def find_aicore_function_span(src, pos): + for i in range(pos - 1, -1, -1): + if src[i] == '}': # this relies on that all [aicore] functions come after normal functions + left = i + 1 + break + n = len(src) + brace_nest = 0 + for j in range(pos, n, 1): + if src[j] == '{': + brace_nest += 1 + elif src[j] == '}': + brace_nest -= 1 + if brace_nest == 0: + right = j + break + return left, right + + new_code = '' + segment_start = 0 + for pos in aicore_positions: + left, right = find_aicore_function_span(code, pos) + new_code += code[segment_start:left] + segment_start = right + 1 + new_code += code[segment_start:] + + # remove __gm__ and rename macros + new_code = new_code.replace('__gm__', ' ') + new_code = new_code.replace('__CCELIB_RT_ERROR_NONE', 'RT_ERROR_NONE') + new_code = new_code.replace('__CCELIB_RT_MEMORY_HBM', 'RT_MEMORY_HBM') + new_code = new_code.replace('__CCELIB_RT_MEMCPY_HOST_TO_DEVICE', 'RT_MEMCPY_HOST_TO_DEVICE') + new_code = new_code.replace('__CCELIB_RT_MEMCPY_DEVICE_TO_HOST', 'RT_MEMCPY_DEVICE_TO_HOST') + return new_code + + # the following headers should be included in this order + return '\n'.join([ + read_header('common/common_impl.h'), + read_header('internal/debug_tunnel/payload.h'), + read_header('internal/debug_tunnel/payload_impl.h'), + read_header('internal/debug_tunnel/tunnel.h'), + read_header('internal/debug_tunnel/tunnel_impl.h') + ]) + + +# the template is from triton-adapter HEAD. 
Wrapping the generated kernel binary into a python module +def generate_npu_wrapper_src(constants, signature, workspace_size, mix_mode, lock_num, lock_ini_val): + import os + + def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def _extracted_ty(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + """ + args: + int gridX, gridY, gridZ; + rtStream_t stream; + const void *functon; + PyObject* packed_metadata, *launch_metadata; + PyObject* launch_enter_hook, *launch_exit_hook; + *args_expand + """ + format = "iiiKKOOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) + + grid_info = {'X': 'i32', 'Y': 'i32', 'Z': 'i32'} + + enable_device_print = os.getenv( + "TRITON_DEVICE_PRINT", 'false').lower() in ('true', '1') + enable_taskqueue = os.getenv( + "TRITON_ENABLE_TASKQUEUE", 'true').lower() in ('true', '1') + enable_auto_map_parallel_blocks = _is_auto_map_parallel_blocks_enabled() + npu_utils = NPUUtils() + num_physical_blocks = npu_utils.get_aivector_core_num( + ) if mix_mode == "aiv" else npu_utils.get_aicore_num() + task_type = "MSPROF_GE_TASK_TYPE_AIV" if mix_mode == "aiv" else "MSPROF_GE_TASK_TYPE_AI_CORE" + LINE_CHANGE_CHAR = chr(10) # it is \n + alloc_success_code = 'return 1;' + sync_lock_fail_code = 'fprintf(stderr, "Error: syncBlockLock allocation failed\\n"); return;' + workspace_fail_code = 'fprintf(stderr, "Error: workspace allocation failed\\n"); return;' + + cpp_device_pointer = """ +typedef struct _DevicePtrInfo { + void *dev_ptr; + bool valid; +} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) { + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) { + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); + return ptr_info; + } + if (obj == Py_None) { + // valid nullptr + return ptr_info; + } + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) { + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + } + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); // Thanks ChatGPT! 
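+ // Note: the early return above (null dev_ptr) skips this Py_DECREF, so ret is only released on the non-null path.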
+ return ptr_info; + } + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +} +""" + + cpp_msprof_extern = """ +extern "C" { + typedef int (* callback)(unsigned int type, void* data, unsigned int len); + extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api); + extern unsigned long int MsprofSysCycleTime(); + extern int MsprofRegisterCallback(unsigned int moduleId, callback handle); + static unsigned int __MsprofFlagL0 = 0; + static unsigned int __MsprofFlagL1 = 0; + + int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) { + if ((CtrlData == nullptr) || (DataLen == 0U)) { + return 1; + } + + if (CtrlType == 1) { + MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData); + if (handle->type >= 6) // 6 is not used here + return 1; + if (handle->type == 1) { // init - 0 , start - 1 + __MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0; + __MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 1 : 0; + } + } + return 0; + } +} +""" + + cpp_msprof_callback = """ + MsprofRegisterCallback(8, ProfCtrlHandle); // 8 - CCE defined in msprof headerfile slog.h +""" + + cpp_msprof_call_before_launch = """ + unsigned long int beginTime = 0; + unsigned long int endTime = 0; + unsigned long int opNameHashID = 0; + unsigned int threadId = 0; + char* _kernelName = const_cast(name.c_str()); + size_t length = name.length(); + if (__MsprofFlagL0 || __MsprofFlagL1) + { + beginTime = MsprofSysCycleTime(); + } +""" + + cpp_msprof_call_after_launch = f""" + if (__MsprofFlagL0 || __MsprofFlagL1) + {{ + endTime = MsprofSysCycleTime(); + opNameHashID = MsprofGetHashId(_kernelName, length); + threadId = (unsigned int)(syscall(SYS_gettid)); + MsprofApi info; + info.level = MSPROF_REPORT_NODE_LEVEL; + info.magicNumber = 0x5a5a; //MSPROF_REPORT_DATA_MAGIC_NUM + info.type = MSPROF_REPORT_NODE_LAUNCH_TYPE; + info.threadId = threadId; + info.reserve = 0; + info.beginTime = beginTime; + info.endTime = endTime; + info.itemId = opNameHashID; + MsprofReportApi(false, &info); + }} + if (__MsprofFlagL1) + {{ + MsprofCompactInfo nodeBasicInfo; + nodeBasicInfo.level = MSPROF_REPORT_NODE_LEVEL; + nodeBasicInfo.magicNumber = 0x5a5a; //MSPROF_REPORT_DATA_MAGIC_NUM + nodeBasicInfo.type = MSPROF_REPORT_NODE_BASIC_INFO_TYPE; + nodeBasicInfo.threadId = threadId; + nodeBasicInfo.timeStamp = endTime; + nodeBasicInfo.data.nodeBasicInfo.opName = opNameHashID; + nodeBasicInfo.data.nodeBasicInfo.opType = opNameHashID; + nodeBasicInfo.data.nodeBasicInfo.taskType = {task_type}; + nodeBasicInfo.data.nodeBasicInfo.blockDim = blockNum; + MsprofReportCompactInfo(0, static_cast(&nodeBasicInfo), sizeof(MsprofCompactInfo)); + + // Report tensor info + int max_tensors_num = tensorShapes.size() < MSPROF_GE_TENSOR_DATA_NUM ? 
tensorShapes.size() : MSPROF_GE_TENSOR_DATA_NUM; + MsprofAdditionalInfo tensorInfo; + tensorInfo.level = MSPROF_REPORT_NODE_LEVEL; + tensorInfo.type = MSPROF_REPORT_NODE_TENSOR_INFO_TYPE; + tensorInfo.threadId = threadId; + tensorInfo.timeStamp = endTime; + auto profTensorData = reinterpret_cast(tensorInfo.data); + profTensorData->opName = opNameHashID; + int tensorCount = 0; + int dataTypes[MSPROF_GE_TENSOR_DATA_NUM]; + if (tensorShapes.size() > 0) {{ + {LINE_CHANGE_CHAR.join( + f'dataTypes[{i}] = {convert_sigtype_to_int(ty[1:])};' + for i, ty in signature.items() + if ty.startswith("*") and i < 5 + )} + }} + for (int i = 0; i < tensorShapes.size() && tensorCount < MSPROF_GE_TENSOR_DATA_NUM; i++) {{ + auto fillTensorData = [&](int index, int tensorType) {{ + profTensorData->tensorData[index].tensorType = tensorType; + profTensorData->tensorData[index].format = 2; // GeDataFormat: ND = 2 + profTensorData->tensorData[index].dataType = dataTypes[i]; + int nDim = tensorShapes[i].size(); + nDim = nDim < MSPROF_GE_TENSOR_DATA_SHAPE_LEN ? nDim : MSPROF_GE_TENSOR_DATA_SHAPE_LEN; + for (int j = 0; j < nDim; j++) {{ + profTensorData->tensorData[index].shape[j] = tensorShapes[i][j]; + }} + for (int j = nDim; j < MSPROF_GE_TENSOR_DATA_SHAPE_LEN; j++) {{ + profTensorData->tensorData[index].shape[j] = 0; + }} + }}; + int tensorType = (i < tensorKinds.size()) ? tensorKinds[i] : 0; // DeFault tensor type is input + if (tensorType == TENSOR_KIND_INPUT || tensorType == TENSOR_KIND_INPUT_OUTPUT) {{ + fillTensorData(tensorCount, MSPROF_GE_TENSOR_TYPE_INPUT); + tensorCount++; + }} + if ((tensorType == TENSOR_KIND_OUTPUT || tensorType == TENSOR_KIND_INPUT_OUTPUT) && tensorCount < MSPROF_GE_TENSOR_DATA_NUM){{ + fillTensorData(tensorCount, MSPROF_GE_TENSOR_TYPE_OUTPUT); + tensorCount++; + }} + }} + profTensorData->tensorNum = tensorCount; + MsprofReportAdditionalInfo(false, static_cast(&tensorInfo), sizeof(MsprofAdditionalInfo)); + }} +""" + + return f""" +#include +#include +#include +#include +#include +#define PY_SSIZE_T_CLEAN +#include +{'#include ' if enable_taskqueue else ''} +#include "experiment/runtime/runtime/rt.h" +#include +{extract_device_print_code_from_cann() if enable_device_print else ''} + +#define TENSOR_KIND_INPUT 0 +#define TENSOR_KIND_OUTPUT 1 +#define TENSOR_KIND_INPUT_OUTPUT 2 + +{cpp_msprof_extern} + +{cpp_device_pointer} + +static void _launch(const char* kernelName, const void* func, rtStream_t stream, int gridX, int gridY, int gridZ, std::vector> &tensorShapes, std::vector &tensorKinds{', ' + arg_decls if len(signature) > 0 else ''}) {{ + // only 1D parallelization is supported for NPU + // Pointer type becomes flattend 1-D Memref tuple: base_ptr, data_ptr, offset, shape, stride + // base_ptr offset shape and stride are not used, arbitrarily set for now + std::string name = ""; + name.append(kernelName); + {'auto launch_call = [=]() -> rtError_t' if enable_taskqueue else ''} {{ + uint32_t blockNum = gridX * gridY * gridZ; + {'blockNum = std::min(blockNum, (uint32_t)' + str(num_physical_blocks) + ');' if enable_auto_map_parallel_blocks else ''} + {'cce::internal::DebugTunnelData *DTData = cce::internal::DebugTunnel::Open(blockNum);' if enable_device_print else ''} + rtError_t ret; + void *ffts_addr = NULL; + uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len); + if (ret != RT_ERROR_NONE) {{ + return {'ret' if enable_taskqueue else ''}; + }} + c10::DataPtr syncBlockLock_ptr; + c10::DataPtr workspace_addr_ptr; + uint16_t ModuleId = 0; + auto* npu_allocator = 
c10_npu::NPUCachingAllocator::get(); + {f''' + uint64_t syncBlockLockSize = {lock_num} * sizeof(int64_t); + syncBlockLock_ptr = npu_allocator->allocate(syncBlockLockSize); + if (!syncBlockLock_ptr) {{ + {alloc_success_code if enable_taskqueue else sync_lock_fail_code} + }} + std::vector lockInitData({lock_num}, {lock_ini_val}); + ret = rtMemcpy( + syncBlockLock_ptr.get(), syncBlockLockSize, + reinterpret_cast(lockInitData.data()), syncBlockLockSize, + RT_MEMCPY_HOST_TO_DEVICE + ); + if (ret != RT_ERROR_NONE) {{ + return {'ret' if enable_taskqueue else ''}; + }} + ''' if lock_num > 0 else ''} + {f''' + uint64_t totalWorkSpaceSize = {workspace_size} * blockNum; + workspace_addr_ptr = npu_allocator->allocate(totalWorkSpaceSize); + if (!workspace_addr_ptr) {{ + {alloc_success_code if enable_taskqueue else workspace_fail_code} + }} + ''' if workspace_size > 0 else ''} + struct __attribute__((packed)) {{ + void* ffts_addr __attribute__((aligned(8))); + void* syncBlockLock __attribute__((aligned(8))); + void* workspace_addr __attribute__((aligned(8))); + {' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8})));' for i, ty in signature.items() if i not in constants)} + {' '.join(f'{_ty_to_cpp(ty)} grid{mark} __attribute__((aligned(4)));' for mark, ty in grid_info.items())} + {'void* DTData __attribute__((aligned(8)));' if enable_device_print else ''} + }} args = {{ + static_cast(ffts_addr), + {f'syncBlockLock_ptr.get()' if lock_num > 0 else 'nullptr'}, + {f'workspace_addr_ptr.get()' if workspace_size > 0 else 'nullptr'}, + {(', '.join(f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in signature.items() if i not in constants) + ',') if len(signature) > 0 else ''} + {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(grid{mark})' for mark, ty in grid_info.items())} + {', static_cast(DTData)' if enable_device_print else ''} + }}; + {cpp_msprof_call_before_launch} + ret = rtKernelLaunch(func, blockNum, static_cast(&args), sizeof(args), NULL, stream); + {'void *&stream_ref = const_cast(stream);' if enable_device_print else ''} + {'cce::internal::DebugTunnel::Close(DTData, stream_ref);' if enable_device_print else ''} + {cpp_msprof_call_after_launch} + {'return ret;' if enable_taskqueue else ''} + }}; + {'at_npu::native::OpCommand cmd; cmd.Name(name.c_str()).SetCustomHandler(launch_call).Run();' if enable_taskqueue else ''} + return; +}} + +// Extract tensor shape from PyObject +static std::vector _get_tensor_shape(PyObject *tensor) {{ + std::vector shape; + + // Early return if tensor is None or null + if (!tensor || tensor == Py_None) {{ + return shape; + }} + + // Calling tensor.size() + PyObject* size_result = PyObject_CallMethod(tensor, "size", NULL); + if (!size_result) {{ + return shape; + }} + // Using PySequence_Fast to improve access efficiency + PyObject* seq = PySequence_Fast(size_result, "Expected a sequence from tensor.size()"); + if (seq) {{ + Py_ssize_t len = PySequence_Fast_GET_SIZE(seq); + PyObject** items = PySequence_Fast_ITEMS(seq); + for (Py_ssize_t i = 0; i < len; ++i) {{ + PyObject* dim = items[i]; + if (PyLong_Check(dim)) {{ + shape.push_back(PyLong_AsLong(dim)); + }} + }} + }} + Py_DECREF(seq); + Py_DECREF(size_result); + return shape; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + rtStream_t stream; + const void *function; + PyObject *packedMetadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + std::vector> 
tensorShapes; + {' '.join([f"{_extracted_ty(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple( + args, \"{format}\", + &gridX, &gridY, &gridZ, &stream, &function, + &packedMetadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook + {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''} + ) + ) {{ + return NULL; + }} + if (__MsprofFlagL1) + {{ + { + LINE_CHANGE_CHAR.join( + f"{{ auto tmp = _get_tensor_shape(_arg{i}); if (!tmp.empty()) tensorShapes.push_back(tmp); }}" + for i, ty in signature.items() if ty[0] == "*" + ) + } + }} + + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; + }} + + // get kernel_name + PyObject *kernelNameObj = PyDict_GetItemString(packedMetadata, "kernel_name"); + const char *kernelName = PyUnicode_AsUTF8(kernelNameObj); + // get tensor_kinds + std::vector tensorKinds; + PyObject *tensorKindList = PyDict_GetItemString(packedMetadata, "tensor_kinds"); + if (tensorKindList) {{ + int size = PyObject_Size(tensorKindList); + for (int i = 0; i < size; i++) {{ + PyObject *kind = PySequence_GetItem(tensorKindList, i); + tensorKinds.push_back(PyLong_AsLong(kind)); + }} + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0]=="*" else "" for i, ty in signature.items()])}; + _launch(kernelName, function, stream, gridX, gridY, gridZ, tensorShapes, tensorKinds{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + if (PyErr_Occurred()) {{ + return NULL; + }} + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ + return NULL; + }} + Py_RETURN_NONE; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + {cpp_msprof_callback} + return m; +}} +""" diff --git a/third_party/ascend/backend/name.conf b/third_party/ascend/backend/name.conf new file mode 100644 index 000000000..037e3b658 --- /dev/null +++ b/third_party/ascend/backend/name.conf @@ -0,0 +1 @@ +ascend diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp new file mode 100644 index 000000000..0395a2666 --- /dev/null +++ b/third_party/ascend/backend/npu_utils.cpp @@ -0,0 +1,135 @@ +#define PY_SSIZE_T_CLEAN +#include + +#include +#include +#include +#include + +#include "experiment/runtime/runtime/rt.h" + +// Use map to differentiate same name functions from different binary +static std::unordered_map registered_names; +static std::unordered_map> func_stubs; + +static std::tuple +registerKernel(const char *name, const void *data, size_t data_size, int shared, + int device, const char *kernel_mode_str) { + rtError_t rtRet; + + rtDevBinary_t devbin; + devbin.data = data; + devbin.length = data_size; + const std::string kernel_mode{kernel_mode_str}; + if (kernel_mode == "aiv") + devbin.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; + else + devbin.magic = RT_DEV_BINARY_MAGIC_ELF; + devbin.version = 0; + + rtRet = 
rtSetDevice(device); + if (rtRet != RT_ERROR_NONE) { + printf("rtSetDevice failed, 0x%x\n", rtRet); + return {NULL, NULL}; + } + + void *devbinHandle = NULL; + rtRet = rtDevBinaryRegister(&devbin, &devbinHandle); + if (rtRet != RT_ERROR_NONE) { + printf("rtDevBinaryRegister failed, 0x%x\n", rtRet); + return {NULL, NULL}; + } + + std::string stubName = name; + stubName += "_" + std::to_string(registered_names[name]); + registered_names[name]++; + auto registered = func_stubs.emplace(stubName, std::make_unique(0)); + void *func_stub_handle = registered.first->second.get(); + rtRet = rtFunctionRegister(devbinHandle, func_stub_handle, stubName.c_str(), + (void *)name, 0); + if (rtRet != RT_ERROR_NONE) { + printf("rtFunctionRegister failed(stubName = %s), 0x%x\n", stubName.c_str(), + rtRet); + return {NULL, NULL}; + } + + return std::make_tuple(devbinHandle, func_stub_handle); +} + +static PyObject *loadKernelBinary(PyObject *self, PyObject *args) { + const char *name; // kernel name + const char *data; // binary pointer + Py_ssize_t data_size; // binary size + int shared; // shared_memory(meaningless now) + int device; // device ID + const char *kernel_mode; // kernel mode + + if (!PyArg_ParseTuple(args, "ss#iis", &name, &data, &data_size, &shared, + &device, &kernel_mode)) { + return NULL; + } + + auto [module_handle, func_handle] = + registerKernel(name, data, data_size, shared, device, kernel_mode); + + uint64_t mod = reinterpret_cast(module_handle); + uint64_t func = reinterpret_cast(func_handle); + if (PyErr_Occurred()) { + return NULL; + } + + return Py_BuildValue("(KKii)", mod, func, 0, 0); +} + +static PyObject *getArch(PyObject *self, PyObject *args) { + char name[64] = {'\0'}; + + rtError_t rtRet = rtGetSocVersion(name, 64); + + if (rtRet != RT_ERROR_NONE) { + printf("rtGetSocVersion failed, 0x%x", rtRet); + return NULL; + } + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("s", name); +} + +static PyObject *getAiCoreNum(PyObject *self, PyObject *args) { + uint32_t aiCoreCnt; + + rtError_t rtRet = rtGetAiCoreCount(&aiCoreCnt); + + if (rtRet != RT_ERROR_NONE) { + printf("rtGetAiCoreCount failed, 0x%x", rtRet); + return NULL; + } + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("I", aiCoreCnt); +} + +static PyMethodDef NpuUtilsMethods[] = { + {"load_kernel_binary", loadKernelBinary, METH_VARARGS, + "Load NPU kernel binary into NPU driver"}, + {"get_arch", getArch, METH_VARARGS, "Get soc version of NPU"}, + // sentinel + {"get_aicore_num", getAiCoreNum, METH_VARARGS, "Get the number of AI core"}, + {NULL, NULL, 0, NULL}}; + +static PyModuleDef ModuleDef = { + PyModuleDef_HEAD_INIT, "npu_utils", + "Utilities for fetching NPU device info and preparing kernel binary", -1, + NpuUtilsMethods}; + +PyMODINIT_FUNC PyInit_npu_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, NpuUtilsMethods); + return m; +} \ No newline at end of file diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py new file mode 100644 index 000000000..71b03326d --- /dev/null +++ b/third_party/ascend/backend/utils.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
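Before moving on to the Python-side utilities, here is a minimal sketch of how the `npu_utils` extension defined above is meant to be driven from Python. This is illustrative only and not part of the patch: the binary path, kernel name, and device id are placeholders, and it assumes the module has already been compiled and is importable.

```python
# Illustrative usage of the npu_utils extension module defined above.
import npu_utils

soc = npu_utils.get_arch()             # SoC version string reported by the runtime
n_aicore = npu_utils.get_aicore_num()  # number of AI cores

# "kernel.o" and the kernel name are hypothetical placeholders.
with open("kernel.o", "rb") as f:
    binary = f.read()

# Arguments: name, binary, shared-memory size (currently unused), device id, kernel mode.
# Returns (module_handle, func_handle, 0, 0), with the handles as 64-bit integers.
mod, func, _, _ = npu_utils.load_kernel_binary("my_triton_kernel", binary, 0, 0, "aiv")
```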
+import functools +import os +import re +import shutil +import subprocess +import sysconfig +from pathlib import Path + +import pybind11 + + +def downgrade_llir(llir): + llir = _downgrade_mem_attrs(llir) + llir = _downgrade_stacksaverestore_intrinsics(llir) + return llir + + +def _downgrade_mem_attrs(llir: str): + memory_pattern = r"memory\([^()]*\)" + + def replace_mem_attr(m): + attrs = m[0][7:-1].split(",") + if len(attrs) == 0: + return "readnone" + loc_map = {"argmem": 1, "inaccessiblemem": 2, "other": 4} + loc_attr = 0 + rw_map = {"readwrite": 3, "write": 2, "read": 1, "none": 0} + rw_attr = 0 + for attr_pair in attrs: + pair = attr_pair.split(":") + assert len(pair) <= 2 + if len(pair) == 1: + rw = rw_map[pair[0].strip()] + loc = loc_map["other"] # all location + else: + rw = rw_map[pair[1].strip()] + loc_str = pair[0].strip() + if loc_str == "argmem" or loc_str == "inaccessiblemem": + loc = loc_map[loc_str] + else: + loc = loc_map["other"] + if rw > 0: + loc_attr = loc_attr | loc + rw_attr = rw_attr | rw + rev_rw_map = {0: "readnone", 1: "readonly", 2: "writeonly"} + if rw_attr in rev_rw_map: + rw_attr_str = rev_rw_map[rw_attr] + else: + rw_attr_str = "" + rev_loc_map = { + 1: "argmemonly", + 2: "inaccessiblememonly", + 3: "inaccessiblemem_or_argmemonly", + } + if loc_attr in rev_loc_map: + loc_attr_str = rev_loc_map[loc_attr] + else: + loc_attr_str = "" + return rw_attr_str + " " + loc_attr_str + + return re.sub(memory_pattern, replace_mem_attr, llir) + + +def _downgrade_stacksaverestore_intrinsics(llir: str): + llir = re.sub(r"llvm\.stacksave\.\w+", "llvm.stacksave", llir) + llir = re.sub(r"llvm\.stackrestore\.\w+", "llvm.stackrestore", llir) + return llir + + +def _get_triton_adapter_opt_path() -> str: + path = os.path.dirname(__file__) + path = os.path.join(path, "triton-adapter-opt") + return path + + +def _get_mlir_path(path: str, *paths) -> str: + root_path = os.getenv("MLIR_ROOT", "") + if root_path == "": + raise EnvironmentError("MLIR_ROOT is not set.") + return os.path.join(root_path, path, *paths) + + +def _get_llvm_path(path: str, *paths) -> str: + root_path = os.getenv("LLVM_ROOT", "") + if root_path == "": + raise EnvironmentError("LLVM_ROOT is not set.") + return os.path.join(root_path, path, *paths) + + +def _get_npucompiler_path() -> str: + npu_compiler_path = shutil.which("bishengir-compile") + if npu_compiler_path is None: + npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "") + if npu_compiler_root is None: + raise EnvironmentError( + "Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH." + ) + npu_compiler_path = os.path.join(npu_compiler_root, "npuc") + return npu_compiler_path + + +def _get_bisheng_path() -> str: + bisheng_path = shutil.which("bisheng") + if bisheng_path is None: + npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "") + if npu_compiler_root is None: + raise EnvironmentError( + "Couldn't find executable bisheng or TRITON_NPU_COMPILER_PATH" + ) + bisheng_path = os.path.join(npu_compiler_root, "ccec") + return bisheng_path + + +# grep bishengir-compile's option limit-auto-multi-buffer-buffer to check +# if bishengir-compile is a newer version which does not generate kernel_reloc.o +# any more. 
+def _check_bishengir_api_change() -> bool: + bishengir_path = _get_npucompiler_path() + try: + result = subprocess.run( + [bishengir_path, "--help"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if result.returncode == 0 and 'limit-auto-multi-buffer-buffer' in result.stdout: + # bishengir-compile is newer version + return True + else: + # bishengir-compile is older version + return False + except Exception as e: + print(f"ERROR: {e}") + return False + + +def _check_bishengir_is_regbased() -> bool: + bishengir_path = _get_npucompiler_path() + try: + result = subprocess.run( + [bishengir_path, "--help"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if result.returncode == 0 and 'reg-based' in result.stdout: + # bishengir-compile is regbased version + return True + else: + # bishengir-compile is membased version + return False + except Exception as e: + print(f"ERROR: {e}") + return False + + +@functools.lru_cache(None) +def _get_ascend_path() -> str: + path = os.getenv("ASCEND_HOME_PATH", "") + if path == "": + raise EnvironmentError( + "ASCEND_HOME_PATH is not set, source /set_env.sh first" + ) + return Path(path) + + +def _is_ascend_sanitizer_enabled() -> bool: + return os.getenv("TRITON_ENABLE_SANITIZER", "false").lower() in ("true", "1") + + +def _is_debug_line_info_disabled() -> bool: + return os.getenv("TRITON_DISABLE_LINE_INFO", "true").lower() in ("true", "1") + + +def _is_auto_map_parallel_blocks_enabled() -> bool: + if not _enable_unpublished_feature(): + return False + return os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "false").lower() in ("true", "1") + + +def _enable_unpublished_feature() -> bool: + return os.getenv("ENABLE_UNPUBLISHED_FEATURE", "false").lower() in ("true", "1") + + +def _build_npu_ext(obj_name: str, src_path, src_dir, *, kernel_launcher=None) -> str: + suffix = sysconfig.get_config_var("EXT_SUFFIX") + so_path = os.path.join(src_dir, f"{obj_name}{suffix}") + + cxx = os.environ.get("CC") + if cxx is None: + clangxx = shutil.which("clang++") + gxx = shutil.which("g++") + cxx = clangxx if clangxx is not None else gxx + if cxx is None: + raise RuntimeError("Failed to find C++ compiler") + cc_cmd = [cxx, src_path] + # disable all warnings + cc_cmd += [f"-w"] + # find the python library + if hasattr(sysconfig, "get_default_scheme"): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
+ if scheme == "posix_local": + scheme = "posix_prefix" + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + cc_cmd += [f"-I{py_include_dir}"] + # device_print.h + cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"] + asc_path = _get_ascend_path() + cc_cmd += [ + f"-I{os.path.join(asc_path, 'include')}", + f"-I{os.path.join(asc_path, 'include/experiment')}", + f"-I{os.path.join(asc_path, 'include/experiment/msprof')}", + f"-I{pybind11.get_include()}", + f"-L{os.path.join(asc_path, 'lib64')}", + "-lruntime", + "-lascendcl", + ] + + import torch + import torch_npu + + torch_path = os.path.dirname(os.path.realpath(torch.__file__)) + torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__)) + use_cxx11_abi = _check_cxx11_abi() + cc_cmd += [ + f"-I{os.path.join(torch_path, 'include')}", + f"-I{os.path.join(torch_npu_path, 'include')}", + f"-L{os.path.join(torch_npu_path, 'lib')}", + "-ltorch_npu", + f"-D_GLIBCXX_USE_CXX11_ABI={use_cxx11_abi}", + ] + + cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", so_path] + + ret = subprocess.check_call(cc_cmd) + + if ret == 0: + return so_path + else: + raise RuntimeError("Failed to compile " + src_path) + + +def _get_kernel_target(metadata: dict): + if "target" not in metadata: + raise Exception("No target provided!") + sub_target = metadata["target"].arch + assert isinstance(sub_target, str) + if sub_target.startswith("Ascend910B"): + mix_mode = metadata["mix_mode"] + if mix_mode.lower().strip("_").startswith("aiv"): + return "ascend_910b_vec", "c220-vec", "aiv" + elif mix_mode.lower().strip("_").startswith("aic"): + return "ascend_910b_cube", "c220-cube", "aic" + else: + return "ascend_910b", "c220", "mix" + elif sub_target.startswith("Ascend910"): + return "ascend_910", "c100", "mix" + else: + raise NotImplementedError(f"NPU subtarget {sub_target} not supported yet") + + +def _check_cxx11_abi(): + import torch + + return 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 + + +def convert_sigtype_to_int(sigty: str): + MAP_SIGTYPE_TO_INT = { + # Boolean + "i1": 12, # BOOL + # Integer types + "i8": 2, # INT8 + "i16": 6, # INT16 + "i32": 3, # INT32 + "i64": 9, # INT64 + # Unsigned integer types + "u32": 8, # UINT32 + "u64": 10, # UINT64 + # Floating point types + "fp16": 1, # FLOAT16 + "bf16": 27, # DT_BF16 + "fp32": 0, # FLOAT + "fp64": 11, # DOUBLE + } + if sigty not in MAP_SIGTYPE_TO_INT: + raise ValueError(f"Unsupported data type: {sigty}") + + return MAP_SIGTYPE_TO_INT[sigty] diff --git a/third_party/ascend/examples/autotune_cases/01-vector-add.py b/third_party/ascend/examples/autotune_cases/01-vector-add.py new file mode 100644 index 000000000..8136407d9 --- /dev/null +++ b/third_party/ascend/examples/autotune_cases/01-vector-add.py @@ -0,0 +1,83 @@ +""" +Vector Add +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +from triton.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key={"x": "n_elements"}, + split_params={"x": "BLOCK_SIZE"}, + tiling_params={}, + low_dims=["x"], + persistent_reduction=False, + dual_reduction=False, +) +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. 
We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +def add_torch(x, y): + return x + y + + +def add_autotune(x, y): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel[grid](x, y, output, n_elements) + return output + + +def test_add(size: int): + os.environ["TRITON_BENCH_METHOD"] = ( + "npu" # use torch_npu.profiler to get calculating time + ) + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output_torch = add_torch(x, y) + output_triton = add_autotune(x, y) + assert torch.allclose(output_triton, output_torch) + + time_eager = do_bench_npu(lambda: add_torch(x, y)) + time_triton = do_bench_npu(lambda: add_autotune(x, y)) + assert (time_eager / time_triton) >= 0.8 + print(f"Vector Add {size} PASSED!") + + +if __name__ == "__main__": + test_add(98432) diff --git a/third_party/ascend/examples/autotune_cases/02-fused-softmax.py b/third_party/ascend/examples/autotune_cases/02-fused-softmax.py new file mode 100644 index 000000000..f9c6cb2a5 --- /dev/null +++ b/third_party/ascend/examples/autotune_cases/02-fused-softmax.py @@ -0,0 +1,100 @@ +""" +Fused Softmax +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +from triton.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key={"x": "n_rows", "y": "n_cols"}, + split_params={"x": "XBLOCK"}, + tiling_params={"x": "XBLOCK_SUB"}, + low_dims=["y"], + persistent_reduction=False, + dual_reduction=False, +) +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, +): + # starting row of the program + row_start = tl.program_id(0) * XBLOCK + for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): + # The stride represents how much we need to increase the pointer to advance 1 row + row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None] + col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] + xmask = row_offsets < n_rows + ymask = col_offsets < n_cols + mask = xmask & ymask + input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to( + XBLOCK_SUB, BLOCK_SIZE + ) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = ( + tl.sum(numerator, 
axis=1) + .reshape(XBLOCK_SUB, 1) + .broadcast_to(XBLOCK_SUB, BLOCK_SIZE) + ) + softmax_output = numerator / denominator + # Write back output to DRAM + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) + tl.store(output_ptrs, softmax_output, mask=mask) + + +def softmax_torch(x): + return torch.softmax(x, axis=-1) + + +def softmax_autotune(x): + n_rows, n_cols = x.shape + BLOCK_SIZE = n_cols + + # Allocate output + y = torch.empty_like(x) + grid = lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1) + # Create a number of persistent programs. + softmax_kernel[grid]( + y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE + ) + return y + + +def test_softmax(shape, dtype): + os.environ["TRITON_BENCH_METHOD"] = ( + "npu" # use torch_npu.profiler to get calculating time + ) + x = torch.randn(shape, dtype=dtype, device="npu") + + y_torch = softmax_torch(x) + y_triton = softmax_autotune(x) + assert torch.allclose(y_triton, y_torch) + + time_eager = do_bench_npu(lambda: softmax_torch(x)) + time_triton = do_bench_npu(lambda: softmax_autotune(x)) + assert (time_eager / time_triton) >= 0.8 + print(f"Fused Softmax {shape} {dtype} PASSED!") + + +if __name__ == "__main__": + test_softmax((16896, 1024), torch.float32) diff --git a/third_party/ascend/examples/autotune_cases/03-layer-norm.py b/third_party/ascend/examples/autotune_cases/03-layer-norm.py new file mode 100644 index 000000000..5125957b4 --- /dev/null +++ b/third_party/ascend/examples/autotune_cases/03-layer-norm.py @@ -0,0 +1,128 @@ +""" +Layer Normalization +============= +""" + +import os + +import torch +import torch_npu +import triton +import triton.language as tl +from triton.testing import do_bench_npu + + +@triton.autotune( + configs=[], + key={"x": "M", "y": "N"}, + split_params={"x": "XBLOCK_SIZE"}, + tiling_params={"y": "RBLOCK_SIZE"}, + low_dims=["y"], + persistent_reduction=False, + dual_reduction=False, +) +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, + M, # number of columns in X + eps, # epsilon to avoid division by zero + XBLOCK_SIZE: tl.constexpr, + RBLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row_begin = tl.program_id(0) * XBLOCK_SIZE + row_idx = row_begin + tl.arange(0, XBLOCK_SIZE) + row_mask = row_idx < M + row_offsets = row_idx[:, None] * stride + # Compute mean + _mean = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32) + for off in range(0, N, RBLOCK_SIZE): + col_idx = off + tl.arange(0, RBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:, None] & col_mask[None, :] + a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( + tl.float32 + ) + _mean += a + mean = tl.sum(_mean, axis=1, keep_dims=True) / N + # Compute variance + _var = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32) + for off in range(0, N, RBLOCK_SIZE): + col_idx = off + tl.arange(0, RBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:, None] & col_mask[None, :] + x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( + tl.float32 + ) + x = tl.where(mask, x - mean, 0.0) + _var += x * x + var = tl.sum(_var, axis=1, keep_dims=True) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row_idx[:, None], mean, mask=row_mask[:, None]) + tl.store(Rstd + row_idx[:, None], rstd, mask=row_mask[:, None]) + # Normalize and apply linear transformation + for off in range(0, N, RBLOCK_SIZE): + col_idx = off + tl.arange(0, RBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:, None] & col_mask[None, :] + w = tl.load(W + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) + b = tl.load(B + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) + x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to( + tl.float32 + ) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + row_offsets + col_idx[None, :], y, mask=mask) + + +def layer_norm_torch(args): + x, w_shape, weight, bias, eps, dtype = args + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + +def layer_norm_autotune(args): + x, weight, bias, eps = args + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M,), dtype=torch.float32, device=x.device) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + grid = lambda meta: (triton.cdiv(M, meta["XBLOCK_SIZE"]), 1, 1) + # enqueue kernel + _layer_norm_fwd_fused[grid]( # + x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, M, eps # + ) + return y + + +def test_layer_norm(shape, dtype, eps=1e-5): + os.environ["TRITON_BENCH_METHOD"] = ( + "npu" # use torch_npu.profiler to get calculating time + ) + M, N = shape + device = "npu" + x_shape = shape + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device=device) + bias = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + y_torch = layer_norm_torch((x, w_shape, weight, bias, eps, dtype)) + y_triton = layer_norm_autotune((x, weight, bias, eps)) + assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0) + print(f"Layer Normalization {M},{N} {dtype} PASSED!") + + +if __name__ == "__main__": + test_layer_norm((128, 128), torch.float16) diff --git a/third_party/ascend/examples/benchmark_cases/layernorm_perf.py b/third_party/ascend/examples/benchmark_cases/layernorm_perf.py new file mode 100644 index 000000000..eedeb0aec --- /dev/null +++ b/third_party/ascend/examples/benchmark_cases/layernorm_perf.py @@ -0,0 +1,450 @@ +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance 
layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# The *LayerNorm* operator was first introduced in [BA2016]_ as a way to improve the performance +# of sequential models (e.g., Transformers) or neural networks with small batch size. +# It takes a vector :math:`x` as input and produces a vector :math:`y` of the same shape as output. +# The normalization is performed by subtracting the mean and dividing by the standard deviation of :math:`x`. +# After the normalization, a learnable linear transformation with weights :math:`w` and biases :math:`b` is applied. +# The forward pass can be expressed as follows: +# +# .. math:: +# y = \frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} } * w + b +# +# where :math:`\epsilon` is a small constant added to the denominator for numerical stability. +# Let’s first take a look at the forward pass implementation. + +import torch +import torch_npu + +import triton +import triton.language as tl + +import time + +HAS_APEX = False +DEVICE = "npu" + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, + M, # number of columns in X + eps, # epsilon to avoid division by zero + XBLOCK_SIZE: tl.constexpr, + RBLOCK_SIZE: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. + row_begin = tl.program_id(0) * RBLOCK_SIZE + row_idx = row_begin + tl.arange(0,RBLOCK_SIZE) + row_mask = row_idx < M + row_offsets = row_idx[:,None]*stride + # Compute mean + + _mean = tl.zeros((RBLOCK_SIZE, XBLOCK_SIZE), dtype=tl.float32) + for off in range(0, N, XBLOCK_SIZE): + col_idx = off + tl.arange(0, XBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:,None] & col_mask[None,:] + a = tl.load(X + row_offsets + col_idx[None,:], mask=mask, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=1, keep_dims = True) / N + + # Compute variance + _var = tl.zeros((RBLOCK_SIZE, XBLOCK_SIZE), dtype=tl.float32) + for off in range(0, N, XBLOCK_SIZE): + col_idx = off + tl.arange(0, XBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:,None] & col_mask[None,:] + x = tl.load(X + row_offsets + col_idx[None,:], mask=mask, other=0.).to(tl.float32) + x = tl.where(mask, x - mean, 0.) 
+ _var += x * x + var = tl.sum(_var, axis=1, keep_dims=True) / N + + rstd = 1 / tl.sqrt(var + eps) + + # Write mean / rstd + tl.store(Mean + row_idx[:,None], mean, mask = row_mask[:,None]) + tl.store(Rstd + row_idx[:,None], rstd, mask = row_mask[:,None]) + # mean = mean.broadcast_to((RBLOCK_SIZE, XBLOCK_SIZE)) + # rstd = rstd.broadcast_to((RBLOCK_SIZE, XBLOCK_SIZE)) + # Normalize and apply linear transformation + for off in range(0, N, XBLOCK_SIZE): + col_idx = off + tl.arange(0, XBLOCK_SIZE) + col_mask = col_idx < N + mask = row_mask[:,None] & col_mask[None,:] + w = tl.load(W + col_idx, mask=col_mask).reshape((1,XBLOCK_SIZE)) + b = tl.load(B + col_idx, mask=col_mask).reshape((1,XBLOCK_SIZE)) + x = tl.load(X + row_offsets + col_idx[None,:], mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + row_offsets + col_idx[None,:], y, mask=mask) + + +# %% +# Backward pass +# ------------- +# +# The backward pass for the layer normalization operator is a bit more involved than the forward pass. +# Let :math:`\hat{x}` be the normalized inputs :math:`\frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} }` before the linear transformation, +# the Vector-Jacobian Products (VJP) :math:`\nabla_{x}` of :math:`x` are given by: +# +# .. math:: +# \nabla_{x} = \frac{1}{\sigma}\Big( \nabla_{y} \odot w - \underbrace{ \big( \frac{1}{N} \hat{x} \cdot (\nabla_{y} \odot w) \big) }_{c_1} \odot \hat{x} - \underbrace{ \frac{1}{N} \nabla_{y} \cdot w }_{c_2} \Big) +# +# where :math:`\odot` denotes the element-wise multiplication, :math:`\cdot` denotes the dot product, and :math:`\sigma` is the standard deviation. +# :math:`c_1` and :math:`c_2` are intermediate constants that improve the readability of the following implementation. +# +# For the weights :math:`w` and biases :math:`b`, the VJPs :math:`\nabla_{w}` and :math:`\nabla_{b}` are more straightforward: +# +# .. math:: +# \nabla_{w} = \nabla_{y} \odot \hat{x} \quad \text{and} \quad \nabla_{b} = \nabla_{y} +# +# Since the same weights :math:`w` and biases :math:`b` are used for all rows in the same batch, their gradients need to sum up. +# To perform this step efficiently, we use a parallel reduction strategy: each kernel instance accumulates +# partial :math:`\nabla_{w}` and :math:`\nabla_{b}` across certain rows into one of :math:`\text{GROUP_SIZE_M}` independent buffers. +# These buffers stay in the L2 cache and then are further reduced by another function to compute the actual :math:`\nabla_{w}` and :math:`\nabla_{b}`. +# +# Let the number of input rows :math:`M = 4` and :math:`\text{GROUP_SIZE_M} = 2`, +# here's a diagram of the parallel reduction strategy for :math:`\nabla_{w}` (:math:`\nabla_{b}` is omitted for brevity): +# +# .. image:: parallel_reduction.png +# +# In Stage 1, the rows of X that have the same color share the same buffer and thus a lock is used to ensure that only one kernel instance writes to the buffer at a time. +# In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`. +# In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`. 
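Before the Triton kernels, the VJP formulas above can be cross-checked with a small eager-mode reference. The sketch below is illustrative only and not part of the patch; it mirrors the per-row `c1`/`c2` terms used by `_layer_norm_bwd_dx_fused` and the row-summed weight/bias gradients that the two-stage reduction ultimately produces.

```python
# Minimal PyTorch reference for the layer-norm VJPs described above.
# Shapes assumed: x, dy: (M, N); w: (N,); eps matches the forward pass.
import torch

def layer_norm_bwd_reference(x, dy, w, eps=1e-5):
    mean = x.mean(dim=1, keepdim=True)
    rstd = 1.0 / torch.sqrt(x.var(dim=1, unbiased=False, keepdim=True) + eps)
    x_hat = (x - mean) * rstd                     # normalized inputs
    wdy = w * dy                                  # dy ⊙ w, broadcast over rows
    c1 = (x_hat * wdy).mean(dim=1, keepdim=True)  # (1/N) x_hat · (dy ⊙ w), per row
    c2 = wdy.mean(dim=1, keepdim=True)            # (1/N) dy · w, per row
    dx = (wdy - (x_hat * c1 + c2)) * rstd
    dw = (dy * x_hat).sum(dim=0)                  # weight/bias gradients are
    db = dy.sum(dim=0)                            # summed over all rows
    return dx, dw, db
```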
+ + +@triton.jit +def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + X, # pointer to the input + W, # pointer to the weights + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + Lock, # pointer to the lock + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # Map the program id to the elements of X, DX, and DY it should compute. + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + DB = DB + lock_id * N + cols + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.) + wdy = tl.where(mask, wdy, 0.) + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + # Accumulate partial sums for dw/db + partial_dw = (dy * xhat).to(w.dtype) + partial_db = (dy).to(w.dtype) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + partial_db += tl.load(DB, mask=mask) + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton.jit +def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + FINAL_DW, # pointer to the weights gradient + FINAL_DB, # pointer to the biases gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # Map the program id to the elements of DW and DB it should compute. + pid = tl.program_id(0) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Iterate through the rows of DW and DB to sum the partial sums. + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.) + db += tl.load(DB + offs, mask=mask, other=0.) + # Write the final sum to the output. + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) + tl.store(FINAL_DB + cols, sum_db, mask=cols < N) + + +# %% +# Benchmark +# --------- +# +# We can now compare the performance of our kernel against that of PyTorch. +# Here we focus on inputs that have Less than 64KB per feature. +# Specifically, one can set :code:`'mode': 'backward'` to benchmark the backward pass. 
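One practical detail when reading the numbers below: `triton.testing.do_bench` reports latencies in milliseconds, so a true GB/s figure for the y-axis has to be derived from the bytes the forward pass actually moves. A rough, illustrative helper (not part of the patch; it assumes fp16 inputs with fp32 mean/rstd side outputs) might look like:

```python
# Rough achieved-bandwidth estimate for the fused forward pass (illustrative).
# Reads:  x (M*N) and w, b (2*N); writes: y (M*N) plus mean/rstd (2*M, fp32).
def layernorm_fwd_gbps(ms, M, N, elem_size=2):
    bytes_moved = (2 * M * N + 2 * N) * elem_size + 2 * M * 4
    return bytes_moved * 1e-9 / (ms * 1e-3)

# e.g. layernorm_fwd_gbps(ms, M=3072, N=N) instead of the raw-time
# `gbps = lambda ms: ms*1000` placeholder used in bench_layer_norm below.
```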
+ + +device = torch.npu.current_device() +stream = torch.npu.current_stream(device).npu_stream +kernels = {} + +class LayerNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, normalized_shape, weight, bias, eps): + + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + + MAX_FUSED_SIZE = 65536 // x.element_size() + XBLOCK_SIZE = 256 + RBLOCK_SIZE = 32 + NUM_CORE = (M -1) // RBLOCK_SIZE + 1 + num_warps = min(max((N - 1) // XBLOCK_SIZE + 1, 1), 8) + # enqueue kernel + + kernel, num_programs = kernels.get(XBLOCK_SIZE^RBLOCK_SIZE, (None, NUM_CORE)) + if kernel is None: + kernel = _layer_norm_fwd_fused.warmup( x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, M, eps, # + XBLOCK_SIZE = XBLOCK_SIZE, + RBLOCK_SIZE = RBLOCK_SIZE, + grid=(NUM_CORE,)) + kernel._init_handles() + kernels[XBLOCK_SIZE^RBLOCK_SIZE] = (kernel, num_programs) + + kernel[(num_programs,1,1 )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, M, eps, # + stream=stream, + ) + + # _layer_norm_fwd_fused[(NUM_CORE, )]( # + # x_arg, y, weight, bias, mean, rstd, # + # x_arg.stride(0), N, M, eps, # + # XBLOCK_SIZE = XBLOCK_SIZE, + # RBLOCK_SIZE = RBLOCK_SIZE, + # num_warps=num_warps, + # num_ctas=1) + ctx.save_for_backward(x, weight, bias, mean, rstd) + # ctx.BLOCK_SIZE = XBLOCK_SIZE + ctx.num_warps = num_warps + ctx.eps = eps + return y + + @staticmethod + def backward(ctx, dy): + x, w, b, m, v = ctx.saved_tensors + # heuristics for amount of parallel reduction stream for DW/DB + N = w.shape[0] + GROUP_SIZE_M = 64 + if N <= 8192: GROUP_SIZE_M = 96 + if N <= 4096: GROUP_SIZE_M = 128 + if N <= 1024: GROUP_SIZE_M = 256 + # allocate output + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) + _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + dw = torch.empty((N, ), dtype=w.dtype, device=w.device) + db = torch.empty((N, ), dtype=w.dtype, device=w.device) + dx = torch.empty_like(dy) + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + _layer_norm_bwd_dx_fused[(M, )]( # + dx, dy, _dw, _db, x, w, m, v, locks, # + x_arg.stride(0), N, # + BLOCK_SIZE_N=ctx.BLOCK_SIZE, # + GROUP_SIZE_M=GROUP_SIZE_M, # + num_warps=ctx.num_warps) + grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + # accumulate partial sums in separate kernel + _layer_norm_bwd_dwdb[grid]( + _dw, _db, dw, db, min(GROUP_SIZE_M, M), N, # + BLOCK_SIZE_M=32, # + BLOCK_SIZE_N=128, num_ctas=1) + return dx, None, dw, db, None + + +layer_norm = LayerNorm.apply + + +def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + + assert torch.allclose(y_tri, y_ref, atol=1e-2, 
rtol=0) + + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], + x_vals=[512 * i for i in range(20, 30)], + line_arg='provider', + line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []), + line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), + styles=[('blue', '-'), ('green', '-'), ('orange', '-')], + ylabel='GB/s', + plot_name='layer-norm-backward', + args={'M': 3072, 'dtype': torch.float16, 'mode': 'forward'}, # 4096 better + )) +def bench_layer_norm(M, N, dtype, provider, mode='forward', eps=1e-5, device=DEVICE): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + quantiles = [0.5, 0.2, 0.8] + + def y_fwd(): + + if provider == "triton": + return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == "torch": + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + if provider == "apex": + apex_layer_norm = (apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)) + return apex_layer_norm(x) # noqa: F811, E704 + + # forward pass + if mode == 'forward': + gbps = lambda ms: ms*1000 + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) + # backward pass + if mode == 'backward': + y = y_fwd() + gbps = lambda ms: ms*1000 + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, + grad_to_none=[x], rep=500) + return gbps(ms), gbps(max_ms), gbps(min_ms) + +def benchmark_test(fn, fn_triton, args =(), name="gen_fn", times=100, repeat=10): + print(f"--------------------benchmark_{name} for {times * repeat} times--------------------") + stream = torch.npu.current_stream() + # warm_up + stream.synchronize() + for _ in range(10) : + fn_triton(*args) + stream.synchronize() + + start = time.perf_counter() + for _ in range(times * repeat) : + fn_triton(*args) + stream.synchronize() + end = time.perf_counter() + + time_compiled = (end - start) / (times * repeat) + time_compiled *= 1000000 + print(f"time_triton:{time_compiled:.6f}") + + + print(f"Runing eager {name} for {times * repeat} times") + + # warm_up + stream.synchronize() + for _ in range(10) : + std = fn(*args) + stream.synchronize() + + start = time.perf_counter() + for _ in range(times * repeat) : + std = fn(*args) + stream.synchronize() + end = time.perf_counter() + time_eager = (end - start) / (times * repeat) + time_eager *= 1000000 + print(f"time_eager:{time_eager:.6f}") + + accelerated = (time_eager - time_compiled)/time_compiled*100 + print(f"Accelerated: {accelerated:.4f}% eager takes {time_eager:.3f} us, triton takes {time_compiled:.3f} us") + + return accelerated, time_eager, time_compiled + +test_layer_norm(1151, 8192, torch.float16) + +M = 2048 +N = 8192 # 12288 12800 13312 13000 +x_shape = (M, N) +w_shape = (x_shape[-1], ) +weight = torch.rand(w_shape, dtype=torch.float16, device='npu', requires_grad=True) +bias = torch.rand(w_shape, dtype=torch.float16, device='npu', requires_grad=True) +x = -2.3 + 0.5 * torch.randn(x_shape, dtype=torch.float16, device='npu') +eps = 1e-5 +benchmark_test(torch.nn.functional.layer_norm,layer_norm,args=(x, w_shape, weight, bias, eps)) + +# %% +# References +# ---------- +# +# .. 
[BA2016] Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton, "Layer Normalization", Arxiv 2016 diff --git a/third_party/ascend/examples/benchmark_cases/softmax_perf.py b/third_party/ascend/examples/benchmark_cases/softmax_perf.py new file mode 100644 index 000000000..a282d7e15 --- /dev/null +++ b/third_party/ascend/examples/benchmark_cases/softmax_perf.py @@ -0,0 +1,268 @@ +""" +Fused Softmax +============= + +In this tutorial, you will write a fused softmax operation that is significantly faster +than PyTorch's native op for a particular class of matrices: those whose rows can fit in +the NPU's SRAM. + +In doing so, you will learn about: + +* The benefits of kernel fusion for bandwidth-bound operations. + +* Reduction operators in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# Custom NPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. +# Let us consider instead the case of a simple (numerically stabilized) softmax operation: + +import torch +import torch_npu +import triton +import triton.language as tl +from triton.runtime import driver +import time + +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. + """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +# %% +# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` +# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements. +# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads +# X once and does all the necessary computations on-chip. +# Doing so would require reading and writing back only :math:`MN` bytes, so we could +# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). +# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically +# but, as we will see later, it is still far from ideal. + +# %% +# Compute Kernel +# -------------- +# +# Our softmax kernel works as follows: each program loads a set of rows of the input matrix X strided by number of programs, +# normalizes it and writes back the result to the output Y. 
+# +# Note that one important limitation of Triton is that each block must have a +# power-of-two number of elements, so we need to internally "pad" each row and guard the +# memory operations properly if we want to handle any possible input shapes: + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, + XBLOCK:tl.constexpr, num_stages: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) * XBLOCK + XBLOCK_SUB : tl.constexpr = 8 + #for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB) : + # The stride represents how much we need to increase the pointer to advance 1 row + row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:,None] + col_offsets = tl.arange(0, BLOCK_SIZE)[None,:] + xmask = (row_offsets < n_rows) + ymask = (col_offsets < n_cols) + mask = xmask & ymask + input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets ) + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB,1).broadcast_to(XBLOCK_SUB,BLOCK_SIZE) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=1).reshape(XBLOCK_SUB,1).broadcast_to(XBLOCK_SUB,BLOCK_SIZE) + softmax_output = numerator / denominator + # Write back output to DRAM + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets ) + tl.store(output_ptrs, softmax_output, mask=mask) + + +# %% +# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. +# NUM_SM = properties["multiprocessor_count"] +# NUM_REGS = properties["max_num_regs"] +# SIZE_SMEM = properties["max_shared_mem"] +# WARP_SIZE = properties["warpSize"] +target = triton.runtime.driver.active.get_current_target() +kernels = {} + +device = torch.npu.current_device() +stream = torch.npu.current_stream(device).npu_stream + +def softmax(x): + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + num_programs = 32 + + XBLOCK = (n_rows + num_programs -1) // num_programs + BLOCK_SIZE = n_cols + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 8 + + # Number of software piepling stages. + #num_stages = 4 if SIZE_SMEM > 200000 else 2 + num_stages =4 + + # Allocate output + y = torch.empty_like(x) + + + # pre-compile kernel to get register usage and compute thread occupancy. 
+ kernel, num_programs = kernels.get(BLOCK_SIZE, (None, num_programs)) + if kernel is None: + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + XBLOCK=XBLOCK, num_stages=num_stages, num_warps=num_warps, grid=(32, )) + kernel._init_handles() + # n_regs = kernel.n_regs + # size_smem = kernel.metadata.shared + # occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + # occupancy = min(occupancy, SIZE_SMEM // size_smem) + # num_programs = NUM_SM * occupancy + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(32, 1, 1)]( + y, + x, + x.stride(0), + y.stride(0), + n_rows, + n_cols, + stream=stream + ) + return y + + +# %% +# Unit Test +# --------- + +# %% +# We make sure that we test our kernel on a matrix with an irregular number of rows and columns. +# This will allow us to verify that our padding mechanism works. + +def torch_softmax(x): + return torch.softmax(x, axis=-1) + +torch.manual_seed(0) +# x = torch.randn(1823, 781, device='npu') +x = torch.randn(4096, 1024, device='npu') +y_triton = softmax(x) +y_torch = torch_softmax(x) +assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) +#torch.testing.assert_close(y_triton, y_torch, rtol=1e-3, atol=1e-3) +# %% +# As expected, the results are identical. +def benchmark_test(fn, fn_triton, args =(), name="gen_fn", times=100, repeat=10): + print(f"--------------------benchmark_{name} for {times * repeat} times--------------------") + stream = torch.npu.current_stream() + # warm_up + stream.synchronize() + for _ in range(10) : + fn_triton(args) + stream.synchronize() + + start = time.perf_counter() + for _ in range(times * repeat) : + fn_triton(args) + stream.synchronize() + end = time.perf_counter() + + time_compiled = (end - start) / (times * repeat) + time_compiled *= 1000000 + print(f"time_triton:{time_compiled:.6f}") + + + print(f"Runing eager {name} for {times * repeat} times") + + # warm_up + stream.synchronize() + for _ in range(10) : + std = fn(args) + stream.synchronize() + + start = time.perf_counter() + for _ in range(times * repeat) : + std = fn(args) + stream.synchronize() + end = time.perf_counter() + time_eager = (end - start) / (times * repeat) + time_eager *= 1000000 + print(f"time_eager:{time_eager:.6f}") + + accelerated = (time_eager - time_compiled)/time_compiled*100 + print(f"Accelerated: {accelerated:.4f}% eager takes {time_eager:.3f} us, triton takes {time_compiled:.3f} us") + + return accelerated, time_eager, time_compiled + +# x = torch.randn(4096, 1024, device='npu') +benchmark_test(torch_softmax,softmax,args=x) +# %% +# Benchmark +# --------- +# +# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. +# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. 
+ + +# @triton.testing.perf_report( +# triton.testing.Benchmark( +# x_names=['N'], # argument names to use as an x-axis for the plot +# x_vals=[128 * i for i in range(2, 8)], # different possible values for `x_name` +# line_arg='provider', # argument name whose value corresponds to a different line in the plot +# line_vals=['triton', 'torch'], # possible values for `line_arg`` +# line_names=[ +# "Triton", +# "Torch", +# ], # label name for the lines +# styles=[('blue', '-'), ('green', '-')], # line styles +# ylabel="GB/s", # label name for the y-axis +# plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. +# args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` +# )) +# def benchmark(M, N, provider): +# x = torch.randn(M, N, device='npu', dtype=torch.float32) +# #stream = torch.npu.Stream() +# #torch.npu.set_stream(stream) + +# if provider == 'torch': +# ms = triton.testing.do_bench(lambda: torch_softmax(x)) +# if provider == 'triton': +# ms = triton.testing.do_bench(lambda: softmax(x)) +# # gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) +# gbps = lambda ms: ms*1000 +# return gbps(ms) + + +# benchmark.run(show_plots=True, print_data=True) + +# %% +# In the above plot, we can see that: +# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. +# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. +# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape. diff --git a/third_party/ascend/examples/conftest.py b/third_party/ascend/examples/conftest.py new file mode 100644 index 000000000..7c6b7d64e --- /dev/null +++ b/third_party/ascend/examples/conftest.py @@ -0,0 +1,20 @@ +import os +import pytest + + +def pytest_configure(config): + # 仅在工作节点设置设备ID + if hasattr(config, "workerinput"): + worker_id = config.workerinput["workerid"] + device_id = int(worker_id.replace("gw", "")) + os.environ["ASCEND_DEVICE_ID"] = str(device_id) + print(f"\n>> Worker {worker_id} using NPU device {device_id}") + + +# 可选:设备初始化逻辑 +@pytest.fixture(scope="session", autouse=True) +def init_npu_device(): + if "ASCEND_DEVICE_ID" in os.environ: + device_id = os.environ["ASCEND_DEVICE_ID"] + # 在此处添加设备初始化代码 + print(f"Initializing NPU device {device_id}") \ No newline at end of file diff --git a/third_party/ascend/examples/flaggems_cases/op_test_run.sh b/third_party/ascend/examples/flaggems_cases/op_test_run.sh new file mode 100644 index 000000000..2c94d72ab --- /dev/null +++ b/third_party/ascend/examples/flaggems_cases/op_test_run.sh @@ -0,0 +1,417 @@ +#!/bin/bash + +# 获取传入参数 +param="$1" +input_ops="$2" +device_count="${3:-1}" # 默认使用1个设备 +threads_per_device="${4:-64}" # 每个设备线程数,默认64 + +# 定义路径 +DIR_TESTS="tests" +DIR_BENCHMARK="benchmark" +DAILY_LOG_DIR="/home/daily_log" +TIMESTAMP=$(date +"%Y%m%d") +LOG_ARCHIVE="test_flaggems_logs_${TIMESTAMP}.tar.gz" +SUMMARY_FILE="${WORKSPACE}/ascend/examples/summary.txt" # 新增:统计信息文件 + +# 检查日志目录 +mkdir -p "$DAILY_LOG_DIR" || { echo "无法创建日志目录 $DAILY_LOG_DIR"; exit 1; } + +# 中央计数器文件定义 +COUNTER_FILE=$(mktemp) +LOCK_FILE="/tmp/op_test_run.lock" +touch $LOCK_FILE + +# ===== 修改:改进的统计结果收集机制 ===== +# 使用文件存储统计结果 +STATS_DIR=$(mktemp -d) +# 初始化设备统计文件 +for ((device_id=0; device_id < device_count; device_id++)); do + stats_file="${STATS_DIR}/device_${device_id}.stats" + echo "success=0" > "$stats_file" + echo 
"failure=0" >> "$stats_file" + echo "skipped=0" >> "$stats_file" + echo "error=0" >> "$stats_file" +done + +# 原子更新统计 +record_stats() { + local device_id=$1 + local status=$2 # success/failure/skipped/error + local stats_file="${STATS_DIR}/device_${device_id}.stats" + + ( + flock -x 20 + # 读取当前值 + current=$(grep "^${status}=" "$stats_file" | cut -d= -f2) + # 更新值 + new_value=$((current + 1)) + # 替换文件中的值 + sed -i "s/^${status}=.*/${status}=${new_value}/" "$stats_file" + ) 20>"${stats_file}.lock" +} + +# 任务队列管理函数 +init_task_queue() { + local -n arr_ref=$1 + TASK_FILE=$(mktemp) + printf "%s\n" "${arr_ref[@]}" > "$TASK_FILE" + echo 0 > "$TASK_FILE.counter" + echo "${#arr_ref[@]}" > "$COUNTER_FILE.total" + echo 0 > "$COUNTER_FILE.completed" +} + +get_next_task() { + ( + # 文件锁保证原子操作 + flock -x 9 + counter=$(< $TASK_FILE.counter) + total_tasks=$(wc -l < $TASK_FILE) + + if (( counter >= total_tasks )); then + echo "" + return + fi + + task_name=$(sed -n "$((counter+1))p" $TASK_FILE) + echo $((counter+1)) > "$TASK_FILE.counter" + echo "$task_name" + ) 9> "$TASK_FILE.lock" +} + +# 原子更新完成计数器 +update_progress() { + ( + flock -x 11 + local current=$(< $COUNTER_FILE.completed) + echo $((current + 1)) > $COUNTER_FILE.completed + echo $((current + 1)) + ) 11> $LOCK_FILE +} + +# 获取进度信息 +get_progress() { + ( + flock -s 11 # 共享锁(只读) + completed=$(< $COUNTER_FILE.completed) + total=$(< $COUNTER_FILE.total) + echo "$completed $total" + ) 11> $LOCK_FILE +} + +cleanup_tasks() { + rm -f "$TASK_FILE" "$TASK_FILE.counter" "$TASK_FILE.lock" $LOCK_FILE $COUNTER_FILE* +} + +# 算子列表定义 +OPS=("abs" "add" "addmm" "all" "amax" "argmax" "bitwise_and" "bitwise_not" "bitwise_or" "bmm" \ +"cos" "CrossEntryLoss" "div" "dropout" "eq" "exp" "fill" "ge" "gelu" "group_norm" "gt" "isinf" \ +"isnan" "rsub" "le" "linear" "log_softmax" "lt" "max" "mean" "min" "mm" "mul" "mv" \ +"native_dropout" "ne" "neg" "pow" "prod" "reciprocal" "relu" "rsqrt" "sigmoid" "silu" \ +"sin" "softmax" "sub" "sum" "tanh" "triu") + +total_ops=${#OPS[@]} +echo "======================================" +echo "测试算子列表: ${OPS[@]}" +echo "算子总数: $total_ops" +echo "使用设备数量: $device_count" +echo "每设备线程数: $threads_per_device" +echo "======================================" + +# 初始化性能计数器 - 修复开始时间显示问题 +start_time=$(date +%s) # 使用Unix时间戳 + +# 线程执行函数 - 正确性测试 +run_tests_thread() { + local device_id=$1 + local thread_id=$2 + local device_log_dir=$3 + local thread_log_dir="$device_log_dir/thread_${thread_id}" + mkdir -p "$thread_log_dir" + + while true; do + task_name=$(get_next_task) + [[ -z "$task_name" ]] && break + + echo "[设备 $device_id-线程 $thread_id] 正在执行: pytest -m $task_name --ref cpu -sv" + log_file="${thread_log_dir}/result_${task_name}.log" + + # 执行正确性测试并记录时间 + start_op=$(date +%s) + pytest -m $task_name --dist=loadfile --ref cpu -sv &> "$log_file" + exit_code=$? + duration=$(( $(date +%s) - start_op )) + + # 根据退出码记录不同状态 + case $exit_code in + 0) + status="success" + ;; + 1) + status="failure" + ;; + 2) # pytest跳过用例的退出码 + status="skipped" + ;; + *) + status="error" + ;; + esac + + # 记录统计结果 + record_stats $device_id $status + + # 原子更新完成计数 + new_completed=$(update_progress) + + # 获取最新进度状态 + read completed total < <(get_progress) + progress=$(( completed * 100 / total )) + + # 输出结果 + if [ $exit_code -ne 0 ]; then + echo "[错误] [$device_id-$thread_id] $task_name 失败! (用时 ${duration}s, 进度: $completed/$total)" + else + echo "[成功] [$device_id-$thread_id] $task_name 完成! 
(用时 ${duration}s, 进度: $completed/$total)" + fi + done +} + +# 线程执行函数 - 性能测试 +run_benchmark_thread() { + local device_id=$1 + local thread_id=$2 + local device_log_dir=$3 + local thread_log_dir="$device_log_dir/thread_${thread_id}" + mkdir -p "$thread_log_dir" + + while true; do + task_name=$(get_next_task) + [[ -z "$task_name" ]] && break + + echo "[设备 $device_id-线程 $thread_id] 正在执行: pytest -m $task_name --level core --record log" + log_file="${thread_log_dir}/benchmark_${task_name}.log" + perf_file="${thread_log_dir}/perf_${task_name}.log" + + # 执行性能测试并记录时间 + start_op=$(date +%s) + pytest -m $task_name --level core --record "$perf_file" &> "$log_file" + exit_code=$? + duration=$(( $(date +%s) - start_op )) + + # 根据退出码记录不同状态 + case $exit_code in + 0) + status="success" + ;; + 1) + status="failure" + ;; + 2) # pytest跳过用例的退出码 + status="skipped" + ;; + *) + status="error" + ;; + esac + + # 记录统计结果 + record_stats $device_id $status + + # 原子更新完成计数 + new_completed=$(update_progress) + + # 获取最新进度状态 + read completed total < <(get_progress) + progress=$(( completed * 100 / total )) + + # 输出结果 + if [ $exit_code -ne 0 ]; then + echo "[错误] [$device_id-$thread_id] $task_name 性能测试失败! (用时 ${duration}s, 进度: $completed/$total)" + else + echo "[成功] [$device_id-$thread_id] $task_name 性能测试完成! (用时 ${duration}s, 进度: $completed/$total)" + fi + done +} + +# 设备主函数 +run_device() { + local device_id=$1 + local mode=$2 + local device_log_dir="device_${device_id}_logs" + mkdir -p "$device_log_dir" + + # 创建设备内的线程池 + for ((thread_id=0; thread_id < threads_per_device; thread_id++)); do + if [ "$mode" == "tests" ]; then + run_tests_thread $device_id $thread_id "$device_log_dir" & + elif [ "$mode" == "benchmark" ]; then + run_benchmark_thread $device_id $thread_id "$device_log_dir" & + fi + done + + # 等待设备内所有线程完成 + wait + echo "======== 设备 $device_id 上所有任务完成 ========" +} + +# 根据参数执行测试 +if [ "$param" == "tests" ]; then + cd "$DIR_TESTS" || { echo "无法进入目录 $DIR_TESTS"; exit 1; } + + # 创建全局任务队列 + init_task_queue OPS + + # 启动设备主进程 + for ((device_id=0; device_id < device_count; device_id++)); do + ( + export ASCEND_RT_VISIBLE_DEVICES=$device_id + run_device $device_id "tests" + ) & + done + +elif [ "$param" == "benchmark" ]; then + cd "$DIR_BENCHMARK" || { echo "无法进入目录 $DIR_BENCHMARK"; exit 1; } + + # 性能测试使用单线程模式(保证准确性) + if [ "$threads_per_device" -gt 1 ]; then + echo "警告:性能测试模式下自动设置为单线程模式(每个设备1个线程)" + threads_per_device=1 + fi + + # 创建全局任务队列 + init_task_queue OPS + + # 启动设备主进程 + for ((device_id=0; device_id < device_count; device_id++)); do + ( + export ASCEND_RT_VISIBLE_DEVICES=$device_id + run_device $device_id "benchmark" + ) & + done + +else + echo "参数错误! 
用法:" + echo "正确性测试: $0 tests \"算子列表\" [设备数量] [线程数]" + echo "性能测试: $0 benchmark \"算子列表\" [设备数量] [线程数]" + cleanup_tasks + exit 1 +fi + +# 等待所有设备完成 +wait +cleanup_tasks + +# ===== 修改:改进的统计信息汇总 ===== +total_success=0 +total_failure=0 +total_skipped=0 +total_error=0 + +# 按设备汇总结果 +for ((device_id=0; device_id < device_count; device_id++)); do + stats_file="${STATS_DIR}/device_${device_id}.stats" + + if [ -f "$stats_file" ]; then + # 从文件加载统计 + d_success=$(grep '^success=' "$stats_file" | cut -d= -f2) + d_failure=$(grep '^failure=' "$stats_file" | cut -d= -f2) + d_skipped=$(grep '^skipped=' "$stats_file" | cut -d= -f2) + d_error=$(grep '^error=' "$stats_file" | cut -d= -f2) + + total_success=$((total_success + d_success)) + total_failure=$((total_failure + d_failure)) + total_skipped=$((total_skipped + d_skipped)) + total_error=$((total_error + d_error)) + + # 记录设备统计 + echo "设备 $device_id 完成情况: $d_success 成功, $d_failure 失败, $d_skipped 跳过, $d_error 错误" + else + echo "警告: 设备 $device_id 的统计文件未找到" + fi +done + +# 清理统计目录 +rm -rf "$STATS_DIR" + +# 计算总耗时 +total_time=$(( $(date +%s) - start_time )) # 使用绝对时间计算总耗时 +hours=$(( total_time / 3600 )) +minutes=$(( (total_time % 3600) / 60 )) +seconds=$(( total_time % 60 )) +time_str=$(printf "%02dh %02dm %02ds" $hours $minutes $seconds) + +# 计算平均耗时 +if [[ $total_ops -gt 0 ]]; then + completed_ops=$((total_success + total_failure + total_error)) + if [[ $completed_ops -gt 0 ]]; then + avg_time=$((total_time / completed_ops)) + avg_min=$((avg_time / 60)) + avg_sec=$((avg_time % 60)) + avg_str=$(printf "%02dm %02ds" $avg_min $avg_sec) + else + avg_str="N/A" + fi +else + avg_str="N/A" +fi + +# 生成统计信息摘要 +{ + echo "===================== flaggems测试统计摘要 =====================" + echo "执行类型: ${param^}" + echo "开始时间: $(date -d @$start_time '+%Y-%m-%d %H:%M:%S')" + echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')" + echo "测试日期: $(date '+%Y-%m-%d')" + echo "总耗时: $time_str" + echo "--------------------------------------------------------" + echo "总算子数: $total_ops" + echo "成功用例数: $total_success" + echo "失败用例数: $total_failure" + echo "跳过用例数: $total_skipped" + echo "错误用例数: $total_error" + echo "完成用例数: $((total_success + total_failure + total_error))" + + if [[ $total_ops -gt 0 ]]; then + echo "完成率: $(( (total_success + total_failure + total_error) * 100 / total_ops ))%" + else + echo "完成率: N/A" + fi + + if [[ $total_success -gt 0 ]] || [[ $total_failure -gt 0 ]] || [[ $total_error -gt 0 ]]; then + success_rate=$(( total_success * 100 / (total_success + total_failure + total_error) )) + echo "成功率: ${success_rate}%" + else + echo "成功率: N/A" + fi + + echo "平均耗时/算子: $avg_str" + echo "--------------------------------------------------------" + echo "设备数量: $device_count" + echo "每设备线程数: $threads_per_device" + echo "并行效率: $(( (total_success + total_failure + total_error) * 100 / (device_count * threads_per_device * total_time) )) OPS/线程秒" + echo "========================================================" + echo "" +} | tee -a $SUMMARY_FILE # 追加到统计文件并同时输出到控制台 + +# 归档所有日志文件 +log_dirs=($(find . -maxdepth 1 -type d -name "device_*_logs" 2>/dev/null)) +if [ ${#log_dirs[@]} -gt 0 ]; then + echo "归档日志文件到 $LOG_ARCHIVE" + tar -czf "$LOG_ARCHIVE" "${log_dirs[@]}" + + if mv "$LOG_ARCHIVE" "$DAILY_LOG_DIR"; then + echo "日志已保存到: $DAILY_LOG_DIR/$LOG_ARCHIVE" + else + echo "警告:日志移动到 $DAILY_LOG_DIR 失败" + fi + + # 清理临时日志 + rm -rf "${log_dirs[@]}" +else + echo "警告:未找到任何日志目录,跳过归档" +fi + +echo "所有算子测试执行完成!" 
+echo "详细统计信息已追加到: $SUMMARY_FILE" +exit 0 \ No newline at end of file diff --git a/third_party/ascend/examples/flaggems_cases/run_flaggems_test.sh b/third_party/ascend/examples/flaggems_cases/run_flaggems_test.sh new file mode 100644 index 000000000..24efa4ff5 --- /dev/null +++ b/third_party/ascend/examples/flaggems_cases/run_flaggems_test.sh @@ -0,0 +1,10 @@ + +TEST_flaggems="${WORKSPACE}/ascend/examples/flaggems_cases" +cd ${TEST_flaggems} +git init +git clone https://gitee.com/leopold0801/flaggems.git +cd flaggems +git checkout 4f3f548 +mv ../op_test_run.sh ./ +ls -al +bash op_test_run.sh tests fullop 8 32 diff --git a/third_party/ascend/examples/generalization_cases/acc_util.py b/third_party/ascend/examples/generalization_cases/acc_util.py new file mode 100644 index 000000000..96e2d2985 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/acc_util.py @@ -0,0 +1,121 @@ +import numpy as np +import torch +import torch_npu + +eval_standard = { + torch.float32: { + "rtol": 1e-6, + "small_value": 1e-6, + "small_value_atol": 1e-9, + "etol": 1e-4, + }, + torch.float16: { + "rtol": 1e-3, + "small_value": 1e-3, + "small_value_atol": 1e-5, + "etol": 1e-3, + }, + torch.bfloat16: { + "rtol": 4e-3, + "small_value": 1e-3, + "small_value_atol": 1e-5, + "etol": 1e-3, + }, +} + + +def assert_close(gold: torch.Tensor, act: torch.Tensor, eval_type: str = 'DEFAULT'): + gold = gold.cpu() + act = act.cpu() + if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16: + assert gold.dtype == torch.float32, "golden should be f32" + assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor can not have 'inf' or 'nan'" + eps = eval_standard[act.dtype]['small_value'] + rtol = eval_standard[act.dtype]['rtol'] + atol = eval_standard[act.dtype]['small_value_atol'] + if eval_type == 'DEFAULT': + ae = torch.abs(act - gold) + re = ae / torch.abs(gold) + mask = torch.abs(gold) < eps + + print(f"count ae > {atol}: {(ae > atol).sum()}") + print(f"count re > {rtol}: {(re > rtol).sum()}") + + not_close = torch.where(mask, ae > atol, re > rtol) + print(f"count not_close = {torch.sum(not_close).item()}") + print(f"not_close.numel = {not_close.numel()}, gold.numel = {gold.numel()}") + print(f"not close ratio = {torch.sum(not_close).item() / not_close.numel()}") + if not torch.any(not_close): + return False + + assert torch.sum(not_close).item() < not_close.numel() * eps, "actual tensor are not close enough with golden tensor,\ +you can use 'benchmark_compare_close' function to compare again!" + elif eval_type == 'ABS': + act = act.to(gold.dtype) + assert torch.equal(gold, act), "actual tensor and golden tensor are not binary equal!" + else: + assert 0, "ERROR! invalid eval_type" + return False + + +def benchmark_compare_close(gold: torch.Tensor, act: torch.Tensor, std: torch.tensor): + assert act.dtype == std.dtype, "standard tensor's dtype must equal to actual tensor's dtype!" 
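+    # Summary of the checks below: relative errors against the golden (fp32) result are
+    # computed for both the actual tensor and a standard reference tensor, ignoring the
+    # "small value" positions where |golden| <= eps; the actual result passes if its max
+    # relative error, accumulated relative error, small-value error ratio and RMSE stay
+    # within fixed multiples (10x / 2x / 2x / 2x) of the standard's.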
+ if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16: + assert gold.dtype == torch.float32, "golden should be f32" + assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor can not have 'inf' or 'nan'" + + gold = gold.cpu() + act = act.cpu() + std = std.cpu() + + eps = eval_standard[act.dtype]['small_value'] + atol = eval_standard[act.dtype]['small_value_atol'] + + mask = torch.abs(gold) <= eps + small_count = mask.sum().item() + + def calculate_relative_errors_except_small(tensor): + re = torch.abs(gold - tensor) / torch.abs(gold) + return torch.where(mask, 0, re) + + act_re = calculate_relative_errors_except_small(act) + std_re = calculate_relative_errors_except_small(std) + act_ae = torch.abs(gold - std) + std_ae = torch.abs(gold - std) + + # 小值域的定义为golden小于某个阈值 eps + act_small_error_count = (mask & (act_ae > atol)).sum().item() + std_small_error_count = (mask & (std_ae > atol)).sum().item() + act_total = act.numel() + std_total = std.numel() + + act_small_error_ratio = act_small_error_count / act_total + std_small_error_ratio = std_small_error_count / std_total + + def calculate_rmse(tensor): + dlt2 = (tensor - gold) ** 2 + dlt2_except_small_mean = torch.where(mask, 0, dlt2).sum() / small_count + return torch.sqrt(dlt2_except_small_mean) + + act_rmse = calculate_rmse(act) + std_rmse = calculate_rmse(std) + + print(f"act_re.max = {act_re.max()}, std_re.max = {std_re.max()}, limit ratio = 10") + print(f"act_re.sum = {act_re.sum()}, std_re.sum = {std_re.sum()}, limit_ratio = 2") + print( + f"act_small_error_ratio = {act_small_error_ratio}, std_small_error_ratio = {std_small_error_ratio}, limit_ratio = 2") + print(f"act_rmse = {act_rmse}, std_rmse = {std_rmse}, limit_ratio = 2") + + # 条件 1:actual 与 golden 相对误差最大值超过 10 倍 standard 与 golden 相对误差最大值 + assert act_re.max() <= 10 * std_re.max(), "actual re max > stdandard re max's 10 times" + + # 条件 2:actual 与 golden 相对误差均值超过 2 倍 standard 与 golden 相对误差均值 + assert act_re.sum() <= 2 * std_re.sum(), "actual re sum > stdandard re sum's 2 times" + + # 条件 3:actual 小值域 ERROR 占比超过 standard 的两倍 + assert act_small_error_ratio <= 2 * std_small_error_ratio, "act_small_error_ratio > std_small_error_ratio 's 2 times" + + # 条件 4:actual 均方根误差差于 standard 的两倍 + assert act_rmse <= 2 * std_rmse, "act_rmse > std_rmse 's 2 times" + + return False diff --git a/third_party/ascend/examples/generalization_cases/conftest.py b/third_party/ascend/examples/generalization_cases/conftest.py new file mode 100644 index 000000000..5e93b7182 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/conftest.py @@ -0,0 +1,14 @@ +import pytest +import torch + + +@pytest.fixture(scope="session", autouse=True) +def assign_npu(worker_id): + npu_count = torch.npu.device_count() + if worker_id == "master": + npu_id = 0 + else: + idx = int(worker_id.replace("gw", "")) + npu_id = idx % npu_count + torch.npu.set_device(npu_id) + diff --git a/third_party/ascend/examples/generalization_cases/full_run.sh b/third_party/ascend/examples/generalization_cases/full_run.sh new file mode 100755 index 000000000..b0eecba6f --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/full_run.sh @@ -0,0 +1,41 @@ +#!/bin/bash +current_date=$(date +%Y%m%d) +pid_log="process_${current_date}.log" +max_parallel=12 + +fifo="/tmp/$$.fifo" +mkfifo $fifo +exec 9<>$fifo +rm -f $fifo + +for ((i=0; i<$max_parallel; i++)); do + echo >&9 +done + +> "$pid_log" + + +if [ -d logs ]; then + rm -rf logs +fi + +mkdir logs + +while IFS= read -r 
-d $'\0' file; do + read -u 9 + + test_log="./logs/${file%.py}_${current_date}.log" + + { + pytest -sv "$file" -n 16 > "$test_log" 2>&1 + echo >&9 + } & + + echo "[INFO] Activated $(basename "$file"), PID=$!, logging into $test_log." + +done < <(find . -maxdepth 1 -type f -name "test_*.py" ! -name "test_common.py" -print0) + +wait +exec 9>&- + +echo "[INFO] All test processes completed, pids logged into ${pid_log}" diff --git a/third_party/ascend/examples/generalization_cases/test_abs.py b/third_party/ascend/examples/generalization_cases/test_abs.py new file mode 100644 index 000000000..d7a45e5f4 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_abs.py @@ -0,0 +1,144 @@ +import logging + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, avoid_not_support +import math +import logging + + +def torch_pointwise(x0): + if x0.dtype != torch.uint32: + return torch.abs(x0) + else: + return torch.abs(x0.to(torch.float32)) + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.abs(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_abs_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.abs(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, 
new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_pointwise(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) +def test_abs_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_pointwise(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_abs_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_advance.py b/third_party/ascend/examples/generalization_cases/test_advance.py new file mode 100644 index 000000000..9f14cb2b1 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_advance.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
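+#
+# Generalization tests for tl.make_block_ptr + tl.advance: each kernel builds a block
+# pointer at a non-zero offset, advances it back to the origin, then loads the block and
+# stores it unchanged, so the output is expected to match the input for 1D-5D shapes.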
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB,), + strides=(1,), + offsets=(5,), + block_shape=(XB,), + order=(0,), + ) + bbptr = tl.advance(block_ptr_in, (-5,)) + # XB,YB,1 + X = tl.load(bbptr) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB,), + strides=(1,), + offsets=(0,), + block_shape=(XB,), + order=(0,), + ) + tl.store(block_ptr_out, X) + + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffset = tl.program_id(0) + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB), + strides=(YB, 1), + offsets=(6 + xoffset, 5), + block_shape=(XB, YB), + order=(1, 0), + ) + bbptr = tl.advance(block_ptr_in, (-6, -5)) + # XB,YB,1 + X = tl.load(bbptr) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB), + strides=(YB, 1), + offsets=(xoffset, 0), + block_shape=(XB, YB), + order=(1, 0), + ) + tl.store(block_ptr_out, X) + + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(3, 1, 2), + block_shape=(XB, YB, ZB), + order=(2, 1, 0), + ) + bbptr = tl.advance(block_ptr_in, (-3, -1, -2)) + # XB,YB,1 + X = tl.load(bbptr) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, ZB), + order=(2, 1, 0), + ) + tl.store(block_ptr_out, X) + + +@triton.jit +def triton_advance_4d(output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, ): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), + offsets=(6, 5, 4, 3), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), + order=(3, 2, 1, 0), + ) + bbptr = tl.advance(block_ptr_in, (-6, -5, -4, -3)) + x = tl.load(bbptr) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), + offsets=(0, 0, 0, 0), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), + order=(3, 2, 1, 0), + ) + tl.store(block_ptr_out, x) + + +@triton.jit +def triton_advance_5d(output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr, ): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), + offsets=(6, 5, 4, 3, 2), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), + order=(4, 3, 2, 1, 0), + ) + bbptr = 
tl.advance(block_ptr_in, (-6, -5, -4, -3, -2)) + x = tl.load(bbptr) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), + offsets=(0, 0, 0, 0, 0), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), + order=(4, 3, 2, 1, 0), + ) + tl.store(block_ptr_out, x) + + +temporarily_not_support_dtype = ['bool'] + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.full_shape) +def test_npu(dtype, shape): + if dtype in temporarily_not_support_dtype: + return + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + + a = x + blocks = list(x.size()) + strides = list(x.stride()) + grid = (1,) + if len(shape) == 5: + triton_advance_5d[grid](output, x, *blocks, *blocks, *strides) + elif len(shape) == 4: + triton_advance_4d[grid](output, x, *blocks, *blocks, *strides) + elif len(shape) == 3: + fn_npu_3d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=shape[2]) + elif len(shape) == 2: + if x.numel() * x.element_size() > 8192: + fn_npu_2d[shape[0], 1, 1](output, x, y, z, output1, XB=1, YB=shape[1], ZB=1) + else: + fn_npu_2d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) + else: + fn_npu_1d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=1, ZB=1) + + torch.testing.assert_close(output, a) diff --git a/third_party/ascend/examples/generalization_cases/test_and.py b/third_party/ascend/examples/generalization_cases/test_and.py new file mode 100644 index 000000000..75c64cdfd --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_and.py @@ -0,0 +1,146 @@ +import logging +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_pointwise(x, y): + res = x & y + return res + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X & Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_and_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, 
BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val & y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_case2(dtype, shape): + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_pointwise(x, y) + output = torch.zeros_like(ans) + + if len(shape) == 1: + fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_and_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x & y + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_and_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/examples/generalization_cases/test_argmax.py b/third_party/ascend/examples/generalization_cases/test_argmax.py new file mode 100644 index 000000000..7b40d7058 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_argmax.py @@ -0,0 +1,328 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +import logging +import math +import pytest +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size + +logger = logging.getLogger(__name__) + + +# <<<<<<< test_argmax_1d +def torch_argmax(x0, dim, keepdim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + return torch.argmax(x0, dim=dim, keepdim=keepdim).npu() + + +@triton.jit +def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.argmax(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_argmax_1d(dtype, shape): + dtype_size = get_dtype_size(dtype) + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") + return + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmax_1d[1, 1, 1](x0, triton_res, numel, numel) + torch_res = torch_argmax(x0, dim=0, keepdim=True) + test_common.validate_cmp("int32", triton_res, torch_res) + + +# >>>>>>> test_argmax_1d + +# <<<<<<< test_argmax_2d +@triton.jit +def triton_argmax_2d(in_ptr0, out_ptr0, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) + tmp4 = tl.argmax(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('dim', [0, 1]) +def 
test_argmax_2d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") + return + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[1 - dim], ], dtype=torch.int32).npu() + triton_argmax_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_argmax(x0, dim=dim, keepdim=False) + test_common.validate_cmp("int32", triton_res, torch_res) + + +# >>>>>>> test_argmax_2d + +# <<<<<<< test_argmax_3d +def torch_argmax_3d(x0, no_reduce_dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + if no_reduce_dim == 0: + return torch.argmax(torch.max(x0, 1)[0], 1).npu() + elif no_reduce_dim == 1: + return torch.argmax(torch.max(x0, 0)[0], 1).npu() + elif no_reduce_dim == 2: + return torch.argmax(torch.max(x0, 0)[0], 0).npu() + else: + assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" + + +@triton.jit +def triton_argmax_3d_0_1(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.max(x, 0) + ret = tl.argmax(tmp, 0) + oidx = zidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def triton_argmax_3d_0_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.max(x, 0) + ret = tl.argmax(tmp, 1) + oidx = yidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def triton_argmax_3d_1_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.max(x, 1) + ret = tl.argmax(tmp, 1) + oidx = xidx + tl.store(out_ptr + oidx, ret) + + +def triton_argmax_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): + if no_reduce_dim == 0: + triton_argmax_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 1: + triton_argmax_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 2: + triton_argmax_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) +def test_argmax_3d(dtype, shape, no_reduce_dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[no_reduce_dim], ], dtype=torch.int32).npu() + triton_argmax_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) + torch_res = torch_argmax_3d(x0, 
no_reduce_dim) + test_common.validate_cmp("int32", triton_res, torch_res) + +# >>>>>>> test_argmax_3d + + +# <<<<<<< test_argmax_4d +def torch_argmax_4d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.argmax(x0, dim) + + +@triton.jit +def argmax_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB // XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB // MB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_argmax_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + + idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[None, None, :, None] * MB + midx[None, None, None, :] + + x = tl.load(in_ptr + idx) + + argmax_4d(out_ptr, x, XB, YB, ZB, MB, DIM) + + +def triton_argmax_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): + triton_argmax_kernel_4d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 4, 8), + (2, 3, 4, 8), +]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('dim', [0]) +def test_argmax_4d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_argmax_4d(x0, dim).to(torch.int32) + triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() + triton_argmax_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) + + test_common.validate_cmp("int32", triton_res, torch_res) +# >>>>>>> test_argmax_4d + + +# <<<<<<< test_argmax_5d +def torch_argmax_5d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.argmax(x0, dim) + + +@triton.jit +def argmax_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 3: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.argmax(x, DIM), XB * YB * ZB * MB * NB // NB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) + tl.store(out_ptr + 
o_idx, ret) + + +@triton.jit +def triton_argmax_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + + idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] + + x = tl.load(in_ptr + idx) + + argmax_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) + + +def triton_argmax_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): + triton_argmax_kernel_5d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 2, 4, 8), + (2, 2, 3, 4, 8), +]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('dim', [0]) +def test_argmax_5d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_argmax_5d(x0, dim).to(torch.int32) + triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() + triton_argmax_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) + + test_common.validate_cmp("int32", triton_res, torch_res) +# >>>>>>> test_argmax_5d + + +# <<<<<<< test_argmax_1d_bool +@triton.jit +def triton_argmax_1d_bool(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None).to(tl.int1) + tmp4 = tl.argmax(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['bool']) +def test_argmax_1d_bool(dtype, shape): + dtype_size = get_dtype_size(dtype) + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") + return + x0 = test_common.generate_tensor(shape, dtype) + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmax_1d_bool[1, 1, 1](x0.npu(), triton_res, numel, numel) + np_res = np.argmax(x0.numpy()) + np.equal(triton_res.item(), np_res) +# >>>>>>> test_argmax_1d_bool \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_argmin.py b/third_party/ascend/examples/generalization_cases/test_argmin.py new file mode 100644 index 000000000..e48e135ca --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_argmin.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
+import logging +import math +import pytest +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size + +logger = logging.getLogger(__name__) + + +# <<<<<<< test_argmin_1d +def torch_argmin(input_tensor, dim, keepdim): + return torch.argmin(input_tensor, dim=dim, keepdim=keepdim) + + +@triton.jit +def triton_argmin_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.argmin(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_argmin_1d(dtype, shape): + dtype_size = get_dtype_size(dtype) + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") + return + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=torch.int32).npu() + numel = shape[0] + triton_argmin_1d[1, 1, 1](x0, triton_res, numel, numel) + torch_res = torch_argmin(x0, dim=0, keepdim=True) + test_common.validate_cmp("int32", triton_res, torch_res) + + +# >>>>>>> test_argmin_1d + +# <<<<<<< test_argmin_2d +@triton.jit +def triton_argmin_2d(in_ptr0, out_ptr0, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) + tmp4 = tl.argmin(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('dim', [0, 1]) +def test_argmin_2d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + logger.warning(f"dtype:{dtype} shape:{shape} mem overflow") + return + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[1 - dim], ], dtype=torch.int32).npu() + triton_argmin_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_argmin(x0, dim=dim, keepdim=False) + test_common.validate_cmp("int32", triton_res, torch_res) + + +# >>>>>>> test_argmin_2d + +# <<<<<<< test_argmin_3d +def torch_argmin_3d(x0, no_reduce_dim): + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + if no_reduce_dim == 0: + return torch.argmin(torch.min(x0, 1)[0], 1) + elif no_reduce_dim == 1: + return torch.argmin(torch.min(x0, 0)[0], 1) + elif no_reduce_dim == 2: + return torch.argmin(torch.min(x0, 0)[0], 0) + else: + assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" + + +@triton.jit +def triton_argmin_3d_0_1(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * 
znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.min(x, 0) + ret = tl.argmin(tmp, 0) + oidx = zidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def triton_argmin_3d_0_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.min(x, 0) + ret = tl.argmin(tmp, 1) + oidx = yidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def triton_argmin_3d_1_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.min(x, 1) + ret = tl.argmin(tmp, 1) + oidx = xidx + tl.store(out_ptr + oidx, ret) + + +def triton_argmin_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): + if no_reduce_dim == 0: + triton_argmin_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 1: + triton_argmin_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 2: + triton_argmin_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) +def test_argmin_3d(dtype, shape, no_reduce_dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[no_reduce_dim], ], dtype=torch.int32).npu() + triton_argmin_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) + torch_res = torch_argmin_3d(x0, no_reduce_dim) + test_common.validate_cmp("int32", triton_res, torch_res) + +# >>>>>>> test_argmin_3d + + +# <<<<<<< test_argmin_4d +def torch_argmin_4d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.argmin(x0, dim) + + +@triton.jit +def argmin_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB // XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB // MB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_argmin_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + + idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + 
zidx[None, None, :, None] * MB + midx[None, None, None, :] + + x = tl.load(in_ptr + idx) + + argmin_4d(out_ptr, x, XB, YB, ZB, MB, DIM) + + +def triton_argmin_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): + triton_argmin_kernel_4d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 4, 8), + (2, 3, 4, 8), +]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('dim', [0]) +def test_argmin_4d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_argmin_4d(x0, dim).to(torch.int32) + triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() + triton_argmin_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) + + test_common.validate_cmp("int32", triton_res, torch_res) +# >>>>>>> test_argmin_4d + + +# <<<<<<< test_argmin_5d +def torch_argmin_5d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.argmin(x0, dim) + + +@triton.jit +def argmin_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 3: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.argmin(x, DIM), XB * YB * ZB * MB * NB // NB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_argmin_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + + idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] + + x = tl.load(in_ptr + idx) + + argmin_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) + + +def triton_argmin_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): + triton_argmin_kernel_5d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 2, 4, 8), + (2, 2, 3, 4, 8), +]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +@pytest.mark.parametrize('dim', [0]) +def test_argmin_5d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_argmin_5d(x0, dim).to(torch.int32) + triton_res = torch.empty_like(torch_res, dtype=torch.int32).npu() + triton_argmin_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) + + test_common.validate_cmp("int32", triton_res, torch_res) +# >>>>>>> test_argmin_5d + + +# <<<<<<< 
test_argmin_1d_bool
+@triton.jit
+def triton_argmin_1d_bool(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr):
+    xoffset = tl.program_id(0) + tl.arange(0, XBLOCK)
+    tmp0 = tl.load(in_ptr0 + xoffset, None).to(tl.int1)
+    tmp4 = tl.argmin(tmp0, 0)
+    tl.store(out_ptr1, tmp4, None)
+
+
+@pytest.mark.parametrize('shape', TestUtils.test_shape1d)
+@pytest.mark.parametrize('dtype', ['bool'])
+def test_argmin_1d_bool(dtype, shape):
+    dtype_size = get_dtype_size(dtype)
+    if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3):
+        logger.warning(f"dtype:{dtype} shape:{shape} mem overflow")
+        return
+    x0 = test_common.generate_tensor(shape, dtype)
+    triton_res = torch.empty(1, dtype=torch.int32).npu()
+    numel = shape[0]
+    triton_argmin_1d_bool[1, 1, 1](x0.npu(), triton_res, numel, numel)
+    np_res = np.argmin(x0.numpy())
+    assert np.equal(triton_res.item(), np_res)
+# >>>>>>> test_argmin_1d_bool
\ No newline at end of file
diff --git a/third_party/ascend/examples/generalization_cases/test_associative_scan.py b/third_party/ascend/examples/generalization_cases/test_associative_scan.py
new file mode 100644
index 000000000..5c314a1a5
--- /dev/null
+++ b/third_party/ascend/examples/generalization_cases/test_associative_scan.py
@@ -0,0 +1,528 @@
+import math
+import pytest
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+
+import test_common
+from test_common import TestUtils, get_dtype_size
+
+
+def combine_fn_test_torch(a, b, combine_fn):
+    if combine_fn == 'maximum_fn':
+        return torch.maximum(a, b)  # elementwise maximum
+    elif combine_fn == 'minimum_fn':
+        return torch.minimum(a, b)  # elementwise minimum
+    elif combine_fn == 'bitwise_xor_fn':
+        return a ^ b  # bitwise XOR
+    elif combine_fn == 'bitwise_or_fn':
+        return a | b  # bitwise OR
+    elif combine_fn == 'bitwise_and_fn':
+        return a & b  # bitwise AND
+    else:
+        pytest.skip("combine_fn is not in the supported set, skipping.")
+
+
+def torch_func_scan(input: torch.Tensor, dim: int, combine_fn='maximum', reverse=False):
+    """
+    PyTorch reference implementation of associative_scan, matching Triton's semantics.
+    Supports any associative combine_fn (e.g. a | b, a & b, min, max).
+    """
+    dim = dim % input.ndim
+
+    if reverse:
+        input = input.flip(dim)
+
+    N = input.size(dim)
+
+    tensors = torch.unbind(input, dim=dim)
+
+    outputs = []
+
+    carry = tensors[0]
+    outputs.append(carry)
+
+    for i in range(1, N):
+        carry = combine_fn_test_torch(tensors[i], carry, combine_fn)
+        outputs.append(carry)
+
+    output = torch.stack(outputs, dim=dim)
+
+    if reverse:
+        output = output.flip(dim)
+
+    return output
+
+
+@triton.jit
+def bitwise_and_fn(a, b):
+    return a & b
+
+
+@triton.jit
+def bitwise_or_fn(a, b):
+    return a | b
+
+
+@triton.jit
+def bitwise_xor_fn(a, b):
+    return a ^ b
+
+
+@triton.jit
+def minimum_fn(a, b):
+    return tl.minimum(a, b)
+
+
+@triton.jit
+def maximum_fn(a, b):
+    return tl.maximum(a, b)
+
+
+@triton.jit
+def triton_kernel_1d_scan(
+    out_ptr0,
+    in_ptr0,
+    dim: tl.constexpr,
+    reverse: tl.constexpr,
+    numel_x: tl.constexpr,
+    XBLOCK: tl.constexpr,
+    combine_fn_name: tl.constexpr,
+):
+    tl.static_assert(
+        numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel"
+    )
+    idx = tl.arange(0, XBLOCK)
+    x = tl.load(in_ptr0 + idx)
+    if combine_fn_name == "maximum_fn":
+        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn)
+    elif combine_fn_name == "minimum_fn":
+        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn)
+    elif combine_fn_name == "bitwise_or_fn":
+        ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn)
+    elif combine_fn_name == "bitwise_xor_fn":
+ ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) + elif combine_fn_name == "bitwise_and_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) + + tl.store(out_ptr0 + idx, ret) + + +@triton.jit +def triton_kernel_2d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, + combine_fn_name: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx = idx_x[:, None] * numel_r + idx_r[None, :] + x = tl.load(in_ptr0 + idx) + + if combine_fn_name == "maximum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) + elif combine_fn_name == "minimum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) + elif combine_fn_name == "bitwise_or_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) + elif combine_fn_name == "bitwise_xor_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) + elif combine_fn_name == "bitwise_and_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) + tl.store(out_ptr0 + idx, ret) + + +@triton.jit +def triton_kernel_3d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + numel_z: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, + combine_fn_name: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + tl.static_assert( + numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx_z = tl.arange(0, ZBLOCK) + idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] + x = tl.load(in_ptr0 + idx) + if combine_fn_name == "maximum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) + elif combine_fn_name == "minimum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) + elif combine_fn_name == "bitwise_or_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) + elif combine_fn_name == "bitwise_xor_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) + elif combine_fn_name == "bitwise_and_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) + tl.store(out_ptr0 + idx, ret) + + +@triton.jit +def triton_kernel_4d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + XB: tl.constexpr, + YB: tl.constexpr, + ZB: tl.constexpr, + MB: tl.constexpr, + combine_fn_name: tl.constexpr, +): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + + zidx[None, None, :, None] * MB + midx[None, None, None, :]) + x = tl.load(in_ptr0 + idx) + if combine_fn_name == "maximum_fn": + ret = tl.associative_scan(x, axis=dim, 
reverse=reverse, combine_fn=maximum_fn) + elif combine_fn_name == "minimum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) + elif combine_fn_name == "bitwise_or_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) + elif combine_fn_name == "bitwise_xor_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) + elif combine_fn_name == "bitwise_and_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) + tl.store(out_ptr0 + idx, ret) + + +@triton.jit +def triton_kernel_5d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + XB: tl.constexpr, + YB: tl.constexpr, + ZB: tl.constexpr, + MB: tl.constexpr, + NB: tl.constexpr, + combine_fn_name: tl.constexpr, +): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + + nidx[None, None, None, None, :]) + x = tl.load(in_ptr0 + idx) + if combine_fn_name == "maximum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=maximum_fn) + elif combine_fn_name == "minimum_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=minimum_fn) + elif combine_fn_name == "bitwise_or_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_or_fn) + elif combine_fn_name == "bitwise_xor_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_xor_fn) + elif combine_fn_name == "bitwise_and_fn": + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=bitwise_and_fn) + tl.store(out_ptr0 + idx, ret) + + +def triton_func_scan(x, dim, combine_fn, reverse): + res = torch.empty_like(x) + shape = x.size() + + if len(shape) == 1: + if dim >= 1: + pytest.skip("dim >= 1 for 1D tensor, skipping.") + triton_kernel_1d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[0], combine_fn + ) + elif len(shape) == 2: + if dim >= 2: + pytest.skip("dim >= 2 for 2D tensor, skipping.") + triton_kernel_2d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1], combine_fn + ) + elif len(shape) == 3: + if dim >= 3: + pytest.skip("dim >= 3 for 3D tensor, skipping.") + triton_kernel_3d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], x.shape[2], combine_fn + ) + elif len(shape) == 4: + if dim >= 4: + pytest.skip("dim >= 4 for 4D tensor, skipping.") + triton_kernel_4d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], combine_fn + ) + elif len(shape) == 5: + if dim >= 5: + pytest.skip("dim >= 5 for 5D tensor, skipping.") + triton_kernel_5d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4], combine_fn + ) + else: + pytest.skip(f"Unsupported tensor dimension: {len(shape)}") + + return res + + +def should_skip_due_to_mem(dtype, shape): + dtype_size = get_dtype_size(dtype) + total_mem = dtype_size * math.prod(shape) + if dtype in ('int8', 'bool'): + threshold = TestUtils.ub_size / 13 + else: + threshold = TestUtils.ub_size / 6 + + if total_mem >= threshold: + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + + +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) 
+@pytest.mark.parametrize("shape", TestUtils.test_shape1d) +@pytest.mark.parametrize("dim", [0]) +@pytest.mark.parametrize("combine_fn", + ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [False]) +def test_scan_1d(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + x_gold = x + cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) + + test_common.validate_cmp(dtype, triton_res, cpu_res) + + +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize("shape", TestUtils.test_shape2d) +@pytest.mark.parametrize("dim", [1]) +@pytest.mark.parametrize("combine_fn", + ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [False]) +def test_scan_2d(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + x_gold = x + cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) + + test_common.validate_cmp(dtype, triton_res, cpu_res) + + +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize("shape", TestUtils.test_shape3d) +@pytest.mark.parametrize("dim", [2]) +@pytest.mark.parametrize("combine_fn", + ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [False]) +def test_scan_3d(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + x_gold = x + cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) + + test_common.validate_cmp(dtype, triton_res, cpu_res) + + +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize("shape", TestUtils.test_shape4d) +@pytest.mark.parametrize("dim", [3]) +@pytest.mark.parametrize("combine_fn", + ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [False]) +def test_scan_4d(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + x_gold = x + cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) + + test_common.validate_cmp(dtype, triton_res, cpu_res) + + +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize("shape", TestUtils.test_shape5d) +@pytest.mark.parametrize("dim", [4]) +@pytest.mark.parametrize("combine_fn", + ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [False]) +def test_scan_5d(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + x_gold = x + cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) + + x_npu = x.npu() + 
triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
+
+    test_common.validate_cmp(dtype, triton_res, cpu_res)
+
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("shape", TestUtils.test_shape1d)
+@pytest.mark.parametrize("dim", [0])
+@pytest.mark.parametrize("combine_fn",
+                         ['maximum_fn', 'minimum_fn'])
+@pytest.mark.parametrize("reverse", [False])
+def test_scan_float_1d(dtype, shape, dim, combine_fn, reverse):
+    should_skip_due_to_mem(dtype, shape)
+    torch.manual_seed(0)
+    x = test_common.generate_tensor(shape=shape, dtype=dtype)
+    x_gold = x
+    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
+
+    x_npu = x.npu()
+    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
+
+    test_common.validate_cmp(dtype, triton_res, cpu_res)
+
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("shape", TestUtils.test_shape2d)
+@pytest.mark.parametrize("dim", [1])
+@pytest.mark.parametrize("combine_fn",
+                         ['maximum_fn', 'minimum_fn'])
+@pytest.mark.parametrize("reverse", [False])
+def test_scan_float_2d(dtype, shape, dim, combine_fn, reverse):
+    should_skip_due_to_mem(dtype, shape)
+    torch.manual_seed(0)
+    x = test_common.generate_tensor(shape=shape, dtype=dtype)
+    x_gold = x
+    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
+
+    x_npu = x.npu()
+    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
+
+    test_common.validate_cmp(dtype, triton_res, cpu_res)
+
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("shape", TestUtils.test_shape3d)
+@pytest.mark.parametrize("dim", [2])
+@pytest.mark.parametrize("combine_fn",
+                         ['maximum_fn', 'minimum_fn'])
+@pytest.mark.parametrize("reverse", [False])
+def test_scan_float_3d(dtype, shape, dim, combine_fn, reverse):
+    should_skip_due_to_mem(dtype, shape)
+    torch.manual_seed(0)
+    x = test_common.generate_tensor(shape=shape, dtype=dtype)
+    x_gold = x
+    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
+
+    x_npu = x.npu()
+    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
+
+    test_common.validate_cmp(dtype, triton_res, cpu_res)
+
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("shape", TestUtils.test_shape4d)
+@pytest.mark.parametrize("dim", [3])
+@pytest.mark.parametrize("combine_fn",
+                         ['maximum_fn', 'minimum_fn'])
+@pytest.mark.parametrize("reverse", [False])
+def test_scan_float_4d(dtype, shape, dim, combine_fn, reverse):
+    should_skip_due_to_mem(dtype, shape)
+    torch.manual_seed(0)
+    x = test_common.generate_tensor(shape=shape, dtype=dtype)
+    x_gold = x
+    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
+
+    x_npu = x.npu()
+    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
+
+    test_common.validate_cmp(dtype, triton_res, cpu_res)
+
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("shape", TestUtils.test_shape5d)
+@pytest.mark.parametrize("dim", [4])
+@pytest.mark.parametrize("combine_fn",
+                         ['maximum_fn', 'minimum_fn'])
+@pytest.mark.parametrize("reverse", [False])
+def test_scan_float_5d(dtype, shape, dim, combine_fn, reverse):
+    should_skip_due_to_mem(dtype, shape)
+    torch.manual_seed(0)
+    x = test_common.generate_tensor(shape=shape, dtype=dtype)
+    x_gold = x
+    cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse)
+
+    x_npu = x.npu()
+    triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse)
+
+    test_common.validate_cmp(dtype, triton_res, cpu_res)
+
+
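+# The negative tests below document two expected failure modes of the scan op on this
+# backend: bitwise combine functions applied to float16/float32 inputs are rejected at
+# compile time ("unexpected type"), and reverse=True scans raise an MLIR compilation
+# error ("reverse=True is not yet supported for scan op").
+# As a minimal sanity check of the reference helper above (input values chosen purely
+# for illustration):
+#   torch_func_scan(torch.tensor([1, 2, 4]), dim=0, combine_fn='bitwise_or_fn')
+#   returns tensor([1, 3, 7]), i.e. the inclusive forward OR-scan.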
+@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("shape", TestUtils.test_shape1d) +@pytest.mark.parametrize("dim", [0]) +@pytest.mark.parametrize("combine_fn", + ['bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [False]) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_scan_float_invalid(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) + +@pytest.mark.parametrize("dtype", ['int32']) +@pytest.mark.parametrize("shape", TestUtils.test_shape1d) +@pytest.mark.parametrize("dim", [0]) +@pytest.mark.parametrize("combine_fn", + ['maximum_fn', 'minimum_fn', 'bitwise_or_fn', 'bitwise_xor_fn', 'bitwise_and_fn']) +@pytest.mark.parametrize("reverse", [True]) +@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, + "reverse=True is not yet supported for scan op") +def test_scan_float_invalid_reverse(dtype, shape, dim, combine_fn, reverse): + should_skip_due_to_mem(dtype, shape) + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, combine_fn, reverse) + diff --git a/third_party/ascend/examples/generalization_cases/test_atan.py b/third_party/ascend/examples/generalization_cases/test_atan.py new file mode 100644 index 000000000..cc45f0348 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atan.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math +import triton.language.extra.ascend.libdevice as libdevice +def torch_pointwise(x0): + res = torch.atan(x0) + return res + +@triton.jit +def triton_atan(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp2 = libdevice.atan(tmp0) + tl.store(out_ptr0 + (x0), tmp2, None) + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['float32', 'float16']) +def test_case(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + + numel = x0.numel() + ncore = 1 if numel <= 32 else 32 + xblock = math.ceil(numel / ncore) + xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) + + y_ref = torch_pointwise(x0) + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_atan[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_add.py b/third_party/ascend/examples/generalization_cases/test_atomic_add.py new file mode 100644 index 000000000..23c225d82 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_add.py @@ -0,0 +1,510 @@ +import math +import pytest +import torch +import triton + +import triton.language as tl + +import test_common +from test_common import TestUtils +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}] + + +@triton.jit +def atomic_add(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + tl.arange(0, BLOCK_SIZE)[:] + xmask = index < n_elements + + tmp0 = tl.load(in_ptr0 + (index), xmask) + tmp1 = tl.load(out_ptr0 + (index), xmask) + tl.atomic_add(out_ptr1 + (index), tmp0, xmask) + tl.atomic_add(out_ptr1 + (index), tmp1, xmask) + + + +@triton.jit +def atomic_add_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + x = tl.load(x_ptr) # x is scalar or 1D, no mask needed + + # Compute y indices + y_offset = pid * BLOCK_SIZE + y_indices = y_offset + tl.arange(0, BLOCK_SIZE) + y_mask = y_indices < n_elements + + y_value = tl.load(y_ptr + y_indices, y_mask) + # Atomic add: y += x (broadcasted) + tl.atomic_add(out_ptr + y_indices, y_value, mask=y_mask) + tl.atomic_add(out_ptr + y_indices, x, mask=y_mask) + + +# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) +test_cases = [ + ((1, 1, 1, 1), (1, 1, 1, 4), 4), + ((1, 1, 1, 3), (1, 5, 1, 3), 5), + ((3,), (2, 3, 3, 3, 3), 81), + ((3,), (2, 3, 3, 3), 27), + ((3,), (2, 3, 3), 9), + ((3,), (2, 3), 3), +] + + +def promote_dtype(x_dtype, y_dtype): + """ + 如果 y 的精度低于 x, 则提升 y 的精度以匹配 x。 + """ + # 如果两个数据类型一致,直接返回 + if x_dtype == y_dtype: + return y_dtype + + # 构建类型的优先级列表(从低到高) + priority = [ + torch.int8, torch.int16, torch.int32, + torch.float16, torch.bfloat16, torch.float32 + ] + + # 查找两种类型在优先级列表中的位置 + x_priority = priority.index(x_dtype) + y_priority = priority.index(y_dtype) + + # 如果y的优先级比x小,则提升到x的类型 + if y_priority < x_priority: + return x_dtype + else: + return y_dtype + + +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +@pytest.mark.parametrize('x_shape, y_shape, BLOCK_SIZE', test_cases) +def test_atomic_add_broadcast_combined(x_dtype_str, y_dtype_str, x_shape, y_shape, BLOCK_SIZE): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + # 先构造 x0 + x0 = torch.full(x_shape, 83.0000, dtype=x_dtype).npu() + + y_raw_dtype = eval('torch.' 
+ y_dtype_str) + + out_dtype = promote_dtype(x_dtype, y_raw_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + + # 构造y和out + y = torch.full(y_shape, -105, dtype=y_raw_dtype).npu() + out = torch.full(y_shape, 0, dtype=out_dtype).npu() + + # 保存副本用于验证 + x_temp = x0.clone() + y_temp = y.clone() + out_temp = out.clone() + + # 计算网格大小和元素总数 + n_elements = y.numel() + grid = (n_elements // BLOCK_SIZE,) # 自动计算需要的线程块数量 + + # 调用 Triton 核函数 + atomic_add_broadcast[grid]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE + ) + + # 验证结果:y += x (广播加法) + expected = out_temp + y_temp + x_temp + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +def test_atomic_add(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + + x0 = test_common.generate_tensor(shape, x_dtype_str).npu() + x1 = test_common.generate_tensor(shape, y_dtype_str).npu() + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + y = torch.full(x1.shape, 0, dtype=out_dtype).npu() + + # 保存副本用于验证 + x0_temp = x0.clone() + x1_temp = x1.clone() + y_temp = y.clone() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] + atomic_add[shape[0], 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=shape[1]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + atomic_add[grid_size, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + expected = y_temp + x1_temp + x0_temp + torch.testing.assert_close(y, expected) + + +# 3d +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +def test_atomic_add_3d(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' 
+ y_dtype_str) + + x0 = test_common.generate_tensor(shape, x_dtype_str).npu() + x1 = test_common.generate_tensor(shape, y_dtype_str).npu() + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + y = torch.full(x1.shape, 0, dtype=out_dtype).npu() + + # 保存副本用于验证 + x0_temp = x0.clone() + x1_temp = x1.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_add[1, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=shape[0] * shape[1] * shape[2]) + + expected = y_temp + x1_temp + x0_temp + torch.testing.assert_close(y, expected) + + +@triton.jit +def atomic_add_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_add(out_ptr0 + offsets, tmp0) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_add_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() + + x1_ref = x1 + x0_value + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_add_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@triton.jit +def atomic_add_5d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, NB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) + offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 + offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tl.atomic_add(out_ptr + offsets1, tmp0) + tl.atomic_add(out_ptr + offsets1, tmp1) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 1, 2, 1, 1), (1, 1, 2, 1, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +def test_atomic_add_5d(x_dtype_str, y_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' 
+ y_dtype_str) + if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: + x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() + else: + x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() + + if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: + y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() + else: + y = torch.randn(y_shape, dtype=eval('torch.' + y_dtype_str)).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + out = torch.full(y_shape, 0, dtype=out_dtype).npu() + + x0_temp = x0.clone() + y_temp = y.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 5: + triton_shape.append(1) + XB, YB, ZB, MB, NB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 5: + triton_shape1.append(1) + XB1, YB1, ZB1, MB1, NB1 = triton_shape1 + + atomic_add_5d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, MB=MB, NB=NB, + XB1=XB1, YB1=YB1, ZB1=ZB1, MB1=MB1, NB1=NB1, + ) + + expected = out_temp + y_temp + x0_temp + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_add_4d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) + offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tl.atomic_add(out_ptr + offsets1, tmp0) + tl.atomic_add(out_ptr + offsets1, tmp1) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 1, 2, 1), (1, 1, 2, 2)], + [(1, 1, 1, 1), (1, 1, 2, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +def test_atomic_add_4d(x_dtype_str, y_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: + x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() + else: + x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() + + if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: + y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() + else: + y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + out = torch.full(y_shape, 0, dtype=out_dtype).npu() + + x0_temp = x0.clone() + y_temp = y.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 4: + triton_shape.append(1) + XB, YB, ZB, MB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 4: + triton_shape1.append(1) + XB1, YB1, ZB1, MB1 = triton_shape1 + + atomic_add_4d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, MB=MB, + XB1=XB1, YB1=YB1, ZB1=ZB1, MB1=MB1, + ) + + expected = out_temp + y_temp + x0_temp + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_add_3d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tl.atomic_add(out_ptr + offsets1, tmp0) + tl.atomic_add(out_ptr + offsets1, tmp1) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 1, 2), (1, 2, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +def test_atomic_add_3d_2(x_dtype_str, y_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: + x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() + else: + x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() + + if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: + y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() + else: + y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + out = torch.full(y_shape, 0, dtype=out_dtype).npu() + + x0_temp = x0.clone() + y_temp = y.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 3: + triton_shape.append(1) + XB, YB, ZB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 3: + triton_shape1.append(1) + XB1, YB1, ZB1 = triton_shape1 + + atomic_add_3d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, + XB1=XB1, YB1=YB1, ZB1=ZB1, + ) + + expected = out_temp + y_temp + x0_temp + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_add_2d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] + + offsets1 = tl.arange(0, XB1) * (YB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tl.atomic_add(out_ptr + offsets1, tmp0) + tl.atomic_add(out_ptr + offsets1, tmp1) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 2), (2, 2)], + [(1, 1), (2, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@pytest.mark.parametrize('y_dtype_str', filtered_dtype) +def test_atomic_add_2d(x_dtype_str, y_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: + x0 = torch.randint(low=0, high=100, size=x0_shape, dtype=x_dtype).npu() + else: + x0 = torch.randn(x0_shape, dtype=eval('torch.' + x_dtype_str)).npu() + + if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: + y = torch.randint(low=0, high=100, size=y_shape, dtype=y_dtype).npu() + else: + y = torch.randn(y_shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + out = torch.full(y_shape, 0, dtype=out_dtype).npu() + + x0_temp = x0.clone() + y_temp = y.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB, YB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 2: + triton_shape1.append(1) + XB1, YB1 = triton_shape1 + + atomic_add_2d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, + XB1=XB1, YB1=YB1, + ) + + expected = out_temp + y_temp + x0_temp + torch.testing.assert_close(out, expected) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_and.py b/third_party/ascend/examples/generalization_cases/test_atomic_and.py new file mode 100644 index 000000000..d532324b1 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_and.py @@ -0,0 +1,513 @@ +import math +import pytest +import torch +import triton + +import triton.language as tl + +import test_common +from test_common import TestUtils +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'int64', 'bool'}] + + +@triton.jit +def atomic_and(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): + in_offset = tl.program_id(0) * BLOCK_SIZE + out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE + in_index = in_offset + tl.arange(0, BLOCK_SIZE) + out_index = out_offset + tl.arange(0, BLOCK_SIZE) + xmask = in_index < n_elements + + tmp0 = tl.load(in_ptr0 + (in_index), xmask) + tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask) + + + +@triton.jit +def atomic_and_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + x = tl.load(x_ptr) # x is scalar or 1D, no mask needed + + # Compute y indices + y_offset = pid * BLOCK_SIZE + y_indices = y_offset + tl.arange(0, BLOCK_SIZE) + y_mask = y_indices < n_elements + + y_value = tl.load(y_ptr + y_indices, y_mask) + # Atomic or: y &= x (broadcasted) + tl.atomic_and(out_ptr + y_indices, y_value, mask=y_mask) + tl.atomic_and(out_ptr + y_indices, x, mask=y_mask) + + +# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) +test_cases = [ + ((1, 1, 1, 1), (1, 1, 1, 4), 4), + ((1, 1, 1, 3), (1, 5, 1, 3), 5), + ((3,), (2, 3, 3, 3, 3), 81), + ((3,), (2, 3, 3, 3), 27), + ((3,), (2, 3, 3), 9), + ((3,), (2, 3), 3), +] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_and(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + # OR的时候任何位和0做OR都不变 任何位和1做AND也都不变,所以为了保持不变 不能用0 只能用1 + y = torch.full(shape, torch.iinfo(x_dtype).max, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] * 2 + atomic_and[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + aligned_size = grid_size * BLOCK_SIZE + x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() + x_concat[0:n_elements] = x[0:n_elements] + x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] + atomic_and[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) + + expected = y_temp & x_temp[0:shape[0]] & x_temp[shape[0]:(shape[0] * 2)] + torch.testing.assert_close(y, expected) + + +# 3d +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_and_3d(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_and[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) + + expected = y_temp & x_temp[0:shape[0]] & x_temp[shape[0]:(shape[0] * 2)] + torch.testing.assert_close(y, expected) + + + +@pytest.mark.parametrize('shape', TestUtils.test_shape_ub_overflow) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, "ub overflow") +def test_atomic_and_ub_overflow(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_and[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) + + + +@triton.jit +def atomic_and_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_and(out_ptr0 + offsets, tmp0) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_and_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' 
+ dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() + + x1_ref = x1 & x0_value + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_and_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@triton.jit +def atomic_and_5d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, NB1: tl.constexpr): + base = tl.program_id(0) * (XB * YB * ZB * MB * NB) + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) + offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 + offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] + + based_offsets = offsets + base + + tmp0 = tl.load(x_ptr + based_offsets) + tl.atomic_and(out_ptr + offsets1, tmp0) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 1, 2, 1, 1), (1, 1, 2, 1, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_and_5d(x_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(x0_shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + + out = torch.full(y_shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 5: + triton_shape.append(1) + XB, YB, ZB, MB, NB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 5: + triton_shape1.append(1) + XB1, YB1, ZB1, MB1, NB1 = triton_shape1 + + atomic_and_5d[(2, )]( + x_ptr=x, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, MB=MB, NB=NB, + XB1=XB1, YB1=YB1, ZB1=ZB1, MB1=MB1, NB1=NB1, + ) + + expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_and_4d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): + base = tl.program_id(0) * (XB * YB * ZB * MB) + offsets = tl.arange(0, XB) * (YB * ZB * MB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) + offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] + + based_offsets = offsets + base + + tmp0 = tl.load(x_ptr + based_offsets) + tl.atomic_and(out_ptr + offsets1, tmp0) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 1, 2, 1), (1, 1, 2, 2)], + [(1, 1, 1, 1), (1, 1, 2, 2)], + ] + ) 
+@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_and_4d(x_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(x0_shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(y_shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 4: + triton_shape.append(1) + XB, YB, ZB, MB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 4: + triton_shape1.append(1) + XB1, YB1, ZB1, MB1 = triton_shape1 + + atomic_and_4d[(2, )]( + x_ptr=x, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, MB=MB, + XB1=XB1, YB1=YB1, ZB1=ZB1, MB1=MB1, + ) + + expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_and_3d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr): + base = tl.program_id(0) * (XB * YB * ZB) + offsets = tl.arange(0, XB) * (YB * ZB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] + + based_offsets = offsets + base + + tmp0 = tl.load(x_ptr + based_offsets) + tl.atomic_and(out_ptr + offsets1, tmp0) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 1, 1), (1, 1, 2)], + [(1, 1, 2), (1, 2, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_and_3d_2(x_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(x0_shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(y_shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 3: + triton_shape.append(1) + XB, YB, ZB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 3: + triton_shape1.append(1) + XB1, YB1, ZB1 = triton_shape1 + + atomic_and_3d[(2, )]( + x_ptr=x, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, + XB1=XB1, YB1=YB1, ZB1=ZB1, + ) + + expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_and_2d(x_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr): + base = tl.program_id(0) * (XB * YB) + offsets = tl.arange(0, XB) * (YB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] + + offsets1 = tl.arange(0, XB1) * (YB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] + + based_offsets = offsets + base + + tmp0 = tl.load(x_ptr + based_offsets) + tl.atomic_and(out_ptr + offsets1, tmp0) + + +@pytest.mark.parametrize('param_list', + [ + [(1, 2), (2, 2)], + [(1, 1), (2, 2)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_and_2d(x_dtype_str, param_list): + x0_shape, y_shape = param_list + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(x0_shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(y_shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB, YB = triton_shape + + triton_shape1 = [*y_shape] + while len(triton_shape1) < 2: + triton_shape1.append(1) + XB1, YB1 = triton_shape1 + + atomic_and_2d[(2, )]( + x_ptr=x, + out_ptr=out, + XB=XB, YB=YB, + XB1=XB1, YB1=YB1, + ) + + expected = out_temp & x_temp[0:x0_shape[0]] & x_temp[x0_shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@triton.jit +def atomic_and(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, + BLOCK_NUM: tl.constexpr, mode: tl.constexpr = 0): + in_offset = tl.program_id(0) * BLOCK_SIZE + out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE + in_index = in_offset + tl.arange(0, BLOCK_SIZE) + out_index = out_offset + tl.arange(0, BLOCK_SIZE) + xmask = in_index < n_elements + + tmp0 = tl.load(in_ptr0 + (in_index), xmask) + if mode ==0: + tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, 'acq_rel', 'cta') + elif mode == 1: + tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, "test") + elif mode == 2: + tl.atomic_and(out_ptr0 + (out_index), tmp0, xmask, "acq_rel", "test") + + +invalid_types_int = [ + 'int64', + 'bool' +] + +@pytest.mark.parametrize("sigtype", invalid_types_int) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "All support dtypes are int8, int16, int32, float16, float32, bfloat16") +def test_invalid_types_int(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + atomic_and[1, 1, 1](x, y, 1, 1, 32) + + +invalid_types_float = [ + 'float16', + 'float32', + 'bfloat16' +] + +@pytest.mark.parametrize("sigtype", invalid_types_float) +@test_common.raises_with_match(triton.compiler.errors.MLIRCompilationError, "must be signless-integer-like") +def test_invalid_types_float(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + atomic_and[1, 1, 1](x, y, 1, 1, 32) + + +default_types = ['int8'] + +@pytest.mark.parametrize("sigtype", default_types) +@pytest.mark.parametrize("test_type", ["sem", "scope"]) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Memory semantic test not supported") +def test_invalid_sem_scope(sigtype, test_type): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + if test_type == "sem": + atomic_and[1, 1, 1](x, y, 1, 1, 32, 1) + elif test_type == "scope": + atomic_and[1, 1, 1](x, y, 1, 1, 32, 2) + + +@triton.jit +def _atomic_and_ss( + in_ptr, out_ptr, n_cols, + BLOCK_SIZE: tl.constexpr, + SEM: tl.constexpr, + SCOPE: tl.constexpr +): + pid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = pid < n_cols + val = tl.load(in_ptr + pid, mask) + tl.atomic_and(out_ptr + pid, val, mask, sem=SEM, scope=SCOPE) + +SEMS = ("relaxed", "acquire", "release", "acq_rel") +SCOPES = ("cta", "gpu", "sys") + +@pytest.mark.parametrize("sem", SEMS) +@pytest.mark.parametrize("scope", SCOPES) +def test_atomic_sem_vs_scope(sem: str, scope: str): + n_cols = 1024 + BLOCK = 128 + grid = (triton.cdiv(n_cols, BLOCK),) + + inp = torch.full((n_cols,), 0xFF, 
dtype=torch.int32, device="npu") + + base = torch.full_like(inp, 0xFF) + _atomic_and_ss[grid](inp, base, n_cols, + BLOCK_SIZE=BLOCK, + SEM="acq_rel", + SCOPE="gpu") + + cur = torch.full_like(inp, 0xFF) + _atomic_and_ss[grid](inp, cur, n_cols, + BLOCK_SIZE=BLOCK, + SEM=sem, + SCOPE=scope) + + torch.testing.assert_close(cur, base) diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_cas.py b/third_party/ascend/examples/generalization_cases/test_atomic_cas.py new file mode 100644 index 000000000..ab14704d2 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_cas.py @@ -0,0 +1,393 @@ +import math +import pytest +import torch +import triton + +import triton.language as tl + +import test_common +from test_common import TestUtils +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'bfloat16', 'int8', 'bool'}] + + +@triton.jit +def atomic_cas(in_ptr0, in_ptr1, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): + in_offset = tl.program_id(0) * BLOCK_SIZE + out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE + in_index = in_offset + tl.arange(0, BLOCK_SIZE) + out_index = out_offset + tl.arange(0, BLOCK_SIZE) + xmask = in_index < n_elements + + tmp0 = tl.load(in_ptr0 + (in_index), xmask) + tmp1 = tl.load(in_ptr1 + (in_index), xmask) + tl.atomic_cas(out_ptr0 + (out_index), tmp1, tmp0) + + +@triton.jit +def atomic_cas_ndim(x_ptr, y_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, + DIM0: tl.constexpr, DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): + sub_idx = tl.program_id(1) + base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE + base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE + offsets_src = tl.arange(0, BLOCK_SIZE) + base_src + offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst + mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 + tmp = tl.load(x_ptr + offsets_src, mask) + tmp_c = tl.load(y_ptr + offsets_src, mask) + tl.atomic_cas(out_ptr + offsets_dst, tmp_c, tmp) + + +@triton.jit +def atomic_cas_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + x = tl.load(x_ptr) # x is scalar or 1D, no mask needed + + # Compute y indices + y_offset = pid * BLOCK_SIZE + y_indices = y_offset + tl.arange(0, BLOCK_SIZE) + y_mask = y_indices < n_elements + + y_value = tl.load(y_ptr + y_indices, y_mask) + # Atomic or: y |= x (broadcasted) + tl.atomic_cas(out_ptr + y_indices, y_value, mask=y_mask) + tl.atomic_cas(out_ptr + y_indices, x, mask=y_mask) + + +# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) +test_cases = [ + ((1, 1, 1, 1), (1, 1, 1, 4), 4), + ((1, 1, 1, 3), (1, 5, 1, 3), 5), + ((3,), (2, 3, 3, 3, 3), 81), + ((3,), (2, 3, 3, 3), 27), + ((3,), (2, 3, 3), 9), + ((3,), (2, 3), 3), +] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + c = torch.randint(low=0, high=2, size=x_shape, dtype=x_dtype).npu() + y = torch.randint(low=0, high=2, size=shape, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + c_temp = c.clone() + y_temp = y.clone() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] * 2 + atomic_cas[shape[0] * 2, 1, 1](x, c, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + aligned_size = grid_size * BLOCK_SIZE + # value + x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() + x_concat[0:n_elements] = x[0:n_elements] + x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] + # compare + c_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() + c_concat[0:n_elements] = c[0:n_elements] + c_concat[aligned_size:(aligned_size + n_elements)] = c[n_elements:(n_elements * 2)] + atomic_cas[grid_size * 2, 1, 1](x_concat, c_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) + + expected = torch.where(y_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], y_temp) + expected = torch.where(expected == c_temp[shape[0]:(shape[0] * 2)], x_temp[shape[0]:(shape[0] * 2)], expected) + torch.testing.assert_close(y, expected) + + +# 3d +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas_3d(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() + y = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + c_temp = c.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_cas[2, 1, 1](x, c, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) + + expected = torch.where(y_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], y_temp) + expected = torch.where(expected == c_temp[shape[0]:(shape[0] * 2)], x_temp[shape[0]:(shape[0] * 2)], expected) + torch.testing.assert_close(y, expected) + + +@triton.jit +def atomic_cas_multi_d(in_ptr0, in_ptr1, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tmp1 = tl.load(in_ptr1 + offsets) + tl.atomic_cas(out_ptr0 + offsets, tmp1, tmp0) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_cas_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' 
+ dtype)).npu() + c = torch.randint(low=2, high=4, size=shape, dtype=eval('torch.' + dtype)).npu() + x1 = torch.randint(low=2, high=4, size=shape, dtype=eval('torch.' + dtype)).npu() + + x1_ref = torch.where(x1 == c, 3, x1) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + + atomic_cas_multi_d[(1, )](x0, c, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1, 1, 2), + (10, 1, 15, 1, 7), + (1, 1, 1, 1, 257), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas_5d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + + c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() + out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() + + x_temp = x.clone() + c_temp = c.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + XB, YB, ZB, MB, NB = triton_shape + BLOCK_SIZE = 256 + ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_cas_ndim[(2 * XB * YB * ZB * MB, ncore)]( + x_ptr=x, + y_ptr=c, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=XB, DIM1=YB, DIM2=ZB, DIM3=MB, DIM4=NB, + ) + + expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) + expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1, 1), + (1, 1, 2, 2), + (1, 3, 2, 7), + (1, 3, 2, 651), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas_4d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() + out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() + + x_temp = x.clone() + c_temp = c.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 4: + triton_shape.append(1) + XB, YB, ZB, MB = triton_shape + + BLOCK_SIZE = 256 + ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_cas_ndim[(2 * XB * YB * ZB, ncore)]( + x_ptr=x, + y_ptr=c, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=XB, DIM2=YB, DIM3=ZB, DIM4=MB, + ) + + expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) + expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1), + (1, 1, 2), + (1, 31, 275), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas_3d_2(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() + out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() + + x_temp = x.clone() + c_temp = c.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 3: + triton_shape.append(1) + XB, YB, ZB = triton_shape + BLOCK_SIZE = 256 + ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_cas_ndim[(2 * XB * YB, ncore)]( + x_ptr=x, + y_ptr=c, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=XB, DIM3=YB, DIM4=ZB, + ) + + expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) + expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 2), + (1, 1), + (257, 1), + (257, 2), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas_2d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu() + out = torch.randint(low=3, high=5, size=shape, dtype=x_dtype).npu() + + x_temp = x.clone() + c_temp = c.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB, YB = triton_shape + BLOCK_SIZE = 256 + ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_cas_ndim[(2 * XB, ncore)]( + x_ptr=x, + y_ptr=c, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=XB, DIM4=YB, + ) + + expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp) + expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', [(1,), (9,), (256,), (257,), (65535,), (65536,)]) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_cas_1d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str)
+    x_shape = list(shape[:])
+    x_shape[0] *= 2
+    x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu()
+    c = torch.randint(low=3, high=5, size=x_shape, dtype=x_dtype).npu()
+    out = torch.full(shape, 0, dtype=x_dtype).npu()
+
+    x_temp = x.clone()
+    c_temp = c.clone()
+    out_temp = out.clone()
+
+    triton_shape = [*shape]
+    while len(triton_shape) < 2:
+        triton_shape.append(1)
+    XB = triton_shape[0]
+    BLOCK_SIZE = 256
+    ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE
+
+    atomic_cas_ndim[(2, ncore)](
+        x_ptr=x,
+        y_ptr=c,
+        out_ptr=out,
+        NCORE=ncore,
+        BLOCK_SIZE=BLOCK_SIZE,
+        DIM0=1, DIM1=1, DIM2=1, DIM3=1, DIM4=XB,
+    )
+
+    expected = torch.where(out_temp == c_temp[0:shape[0]], x_temp[0:shape[0]], out_temp)
+    expected = torch.where(expected == c_temp[shape[0]:(x_shape[0])], x_temp[shape[0]:(x_shape[0])], expected)
+    torch.testing.assert_close(out, expected)
\ No newline at end of file
diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_max.py b/third_party/ascend/examples/generalization_cases/test_atomic_max.py
new file mode 100644
index 000000000..25e6e0a55
--- /dev/null
+++ b/third_party/ascend/examples/generalization_cases/test_atomic_max.py
@@ -0,0 +1,237 @@
+import random
+import triton
+import triton.language as tl
+import torch
+import pytest
+import test_common
+from test_common import TestUtils
+
+@triton.jit
+def triton_test_fn_atomic_max_dma(in_ptr0, in_ptr1, out_ptr1, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+    xoffset = tl.program_id(0) * BLOCK_SIZE
+    index = xoffset + tl.arange(0, BLOCK_SIZE)[:]
+    mask = index < n_elements
+    inp0 = tl.load(in_ptr0 + (index), mask)
+    inp1 = tl.load(in_ptr1 + (index), mask)
+    tmp1 = tl.atomic_max(out_ptr1 + (index), inp0, mask)
+    tmp2 = tl.atomic_max(out_ptr1 + (index), inp1, mask)
+
+
+def promote_dtype(x_dtype, y_dtype):
+    """
+    Promote y to x's precision when y is the lower-precision type.
+    """
+    # If the two dtypes already match, return directly
+    if x_dtype == y_dtype:
+        return y_dtype
+
+    # Dtype priority list, ordered from low to high
+    priority = [
+        torch.int8, torch.int16, torch.int32,
+        torch.float16, torch.bfloat16, torch.float32
+    ]
+
+    # Find the position of each dtype in the priority list
+    x_priority = priority.index(x_dtype)
+    y_priority = priority.index(y_dtype)
+
+    # If y ranks lower than x, promote to x's dtype
+    if y_priority < x_priority:
+        return x_dtype
+    else:
+        return y_dtype
+
+
+# torch.max does not support int
+@pytest.mark.parametrize('shape', random.sample(TestUtils.test_shape2d + TestUtils.test_shape1d, 5))
+@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
+@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16'])
+def test_atomic_max(x_dtype_str, y_dtype_str, shape):
+    # Resolve the torch dtypes from their string names
+    x_dtype = eval('torch.' + x_dtype_str)
+    y_dtype = eval('torch.'
+ y_dtype_str) + + x0 = test_common.generate_tensor(shape, x_dtype_str) + x1 = test_common.generate_tensor(shape, y_dtype_str) + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + out = torch.full(x1.shape, 0, dtype=out_dtype) + + out_ref = torch.maximum(out, x0) + out_ref = torch.maximum(out_ref, x1) + out_ref = out_ref.npu() + x0 = x0.npu() + x1 = x1.npu() + out = out.npu() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] + triton_test_fn_atomic_max_dma[shape[0], 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=shape[1]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + triton_test_fn_atomic_max_dma[grid_size, 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + torch.testing.assert_close(out, out_ref) + +# 3d +testlist = [ + (1,22,39), + (27,1,39), + (27,22,1), + (1,1,23), + (23,1,1), + (1,23,1), + (27,5,3), + (2,29,4), + (7,31,7), + (3,5,8), + (7,17,15), + (25,5,16), + (23,5,31), + (7,11,32), + (7,11,33), + (2,3,255), + (3,3,256), + (3,2,257), +] + + +@pytest.mark.parametrize('shape', random.sample(testlist, 5)) +@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +def test_atomic_max_3d(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + + ncore = 1 + split_size = shape[0] // ncore + x0 = test_common.generate_tensor(shape, x_dtype_str) + x1 = test_common.generate_tensor(shape, y_dtype_str) + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + y = torch.full(shape, 0, dtype=out_dtype) + + out_ref = torch.full_like(x0, 0, dtype=out_dtype) + out_ref = torch.maximum(out_ref, x0) + out_ref = torch.maximum(out_ref, x1) + x0 = x0.npu() + x1 = x1.npu() + y = y.npu() + + n_elements = shape[0] * shape[1] * shape[2] + triton_test_fn_atomic_max_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1] * shape[2]) + y = y.cpu() + torch.testing.assert_close(y, out_ref) + + +@triton.jit +def atomic_max_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_max(out_ptr0 + offsets, tmp0) + + +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}] + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_max_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() + + x1_ref = torch.maximum(x1, x0) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_max_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@triton.jit +def atomic_max_multi_d_2(in_ptr0, out_ptr0, out_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tmp1 = tl.load(out_ptr0 + offsets) + tl.atomic_max(out_ptr1 + offsets, tmp0) + tl.atomic_max(out_ptr1 + offsets, tmp1) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +def test_atomic_max_4d_5d_2(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + + if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: + x0 = torch.randint(low=0, high=100, size=shape, dtype=x_dtype).npu() + else: + x0 = torch.randn(shape, dtype=eval('torch.' + x_dtype_str)).npu() + + if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: + x1 = torch.randint(low=0, high=100, size=shape, dtype=y_dtype).npu() + else: + x1 = torch.randn(shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: + y = torch.full(shape, torch.iinfo(out_dtype).min, dtype=out_dtype).npu() + else: + y = torch.full(shape, float('-inf'), dtype=out_dtype).npu() + + y_tmp = y + x1_ref = torch.maximum(y_tmp, x0) + x1_ref = torch.maximum(x1_ref, x1) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_max_multi_d_2[(1, )](x0, x1, y, *triton_shape) + torch.testing.assert_close(y, x1_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_min.py b/third_party/ascend/examples/generalization_cases/test_atomic_min.py new file mode 100644 index 000000000..0e5e725fe --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_min.py @@ -0,0 +1,245 @@ +import random +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def triton_test_fn_atomic_min_dma(in_ptr0, in_ptr1, out_ptr1, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + xoffset = tl.program_id(0) * BLOCK_SIZE + index = xoffset + tl.arange(0, BLOCK_SIZE)[:] + mask = index < n_elements + inp0 = tl.load(in_ptr0 + (index), mask) + inp1 = tl.load(in_ptr1 + (index), mask) + tmp1 = tl.atomic_min(out_ptr1 + (index), inp0, mask) + tmp2 = tl.atomic_min(out_ptr1 + (index), inp1, mask) + + +def promote_dtype(x_dtype, y_dtype): + """ + 如果 y 的精度低于 x, 则提升 y 的精度以匹配 x。 + """ + # 如果两个数据类型一致,直接返回 + if x_dtype == y_dtype: + return y_dtype + + # 构建类型的优先级列表(从低到高) + priority = [ + torch.int8, torch.int16, torch.int32, + torch.float16, torch.bfloat16, torch.float32 + ] + + # 查找两种类型在优先级列表中的位置 + x_priority = priority.index(x_dtype) + y_priority = priority.index(y_dtype) + + # 如果y的优先级比x小,则提升到x的类型 + if y_priority < x_priority: + return x_dtype + else: + return y_dtype + + +# torch.min do not support int +@pytest.mark.parametrize('shape', random.sample(TestUtils.test_shape2d + TestUtils.test_shape1d, 5)) +@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +def test_atomic_min(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' 
+ y_dtype_str) + + x0 = test_common.generate_tensor(shape, x_dtype_str) + x1 = test_common.generate_tensor(shape, y_dtype_str) + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: # 判断是否是整数类型 + out = torch.full(x1.shape, torch.iinfo(out_dtype).max, dtype=out_dtype) + else: + out = torch.full(x1.shape, torch.finfo(out_dtype).max, dtype=out_dtype) + + out_ref = torch.minimum(out, x0) + out_ref = torch.minimum(out_ref, x1) + out_ref = out_ref.npu() + x0 = x0.npu() + x1 = x1.npu() + out = out.npu() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] + triton_test_fn_atomic_min_dma[shape[0], 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=shape[1]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + triton_test_fn_atomic_min_dma[grid_size, 1, 1](x0, x1, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + torch.testing.assert_close(out, out_ref) + + +# 3d +testlist = [ + (1,22,39), + (27,1,39), + (27,22,1), + (1,1,23), + (23,1,1), + (1,23,1), + (27,5,3), + (2,29,4), + (7,31,7), + (3,5,8), + (7,17,15), + (25,5,16), + (23,5,31), + (7,11,32), + (7,11,33), + (2,3,255), + (3,3,256), + (3,2,257), +] + + +@pytest.mark.parametrize('shape', random.sample(testlist, 5)) +@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +def test_atomic_min_3d(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + + ncore = 1 + split_size = shape[0] // ncore + x0 = test_common.generate_tensor(shape, x_dtype_str) + x1 = test_common.generate_tensor(shape, y_dtype_str) + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: + y = torch.full(shape, torch.iinfo(out_dtype).max, dtype=out_dtype) + else: + y = torch.full(shape, float('inf'), dtype=out_dtype) + + y_tmp = y + x1_ref = torch.minimum(y_tmp, x0) + x1_ref = torch.minimum(x1_ref, x1) + x0 = x0.npu() + x1 = x1.npu() + y = y.npu() + + n_elements = shape[0] * shape[1] * shape[2] + triton_test_fn_atomic_min_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1] * shape[2]) + y = y.cpu() + torch.testing.assert_close(y, x1_ref) + + +@triton.jit +def atomic_min_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_min(out_ptr0 + offsets, tmp0) + + +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'int64', 'bool'}] + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 
4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_min_4d_5d(dtype, shape): + x0_value = 1 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() + + x1_ref = torch.minimum(x1, x0) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_min_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@triton.jit +def atomic_min_multi_d_2(in_ptr0, out_ptr0, out_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tmp1 = tl.load(out_ptr0 + offsets) + tl.atomic_min(out_ptr1 + offsets, tmp0) + tl.atomic_min(out_ptr1 + offsets, tmp1) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('x_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +@pytest.mark.parametrize('y_dtype_str', ['float32', 'int32', 'int8', 'int16', 'bfloat16', 'float16']) +def test_atomic_min_4d_5d_2(x_dtype_str, y_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + + if x_dtype == torch.int8 or x_dtype == torch.int16 or x_dtype == torch.int32: + x0 = torch.randint(low=0, high=100, size=shape, dtype=x_dtype).npu() + else: + x0 = torch.randn(shape, dtype=eval('torch.' + x_dtype_str)).npu() + + if y_dtype == torch.int8 or y_dtype == torch.int16 or y_dtype == torch.int32: + x1 = torch.randint(low=0, high=100, size=shape, dtype=y_dtype).npu() + else: + x1 = torch.randn(shape, dtype=eval('torch.' 
+ y_dtype_str)).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + if out_dtype == torch.int8 or out_dtype == torch.int16 or out_dtype == torch.int32: + y = torch.full(shape, torch.iinfo(out_dtype).max, dtype=out_dtype).npu() + else: + y = torch.full(shape, float('inf'), dtype=out_dtype).npu() + + y_tmp = y + x1_ref = torch.minimum(y_tmp, x0) + x1_ref = torch.minimum(x1_ref, x1) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_min_multi_d_2[(1, )](x0, x1, y, *triton_shape) + torch.testing.assert_close(y, x1_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_or.py b/third_party/ascend/examples/generalization_cases/test_atomic_or.py new file mode 100644 index 000000000..579304cab --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_or.py @@ -0,0 +1,357 @@ +import math +import pytest +import torch +import triton + +import triton.language as tl + +import test_common +from test_common import TestUtils +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'int64', 'bool'}] + + +@triton.jit +def atomic_or(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): + in_offset = tl.program_id(0) * BLOCK_SIZE + out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE + in_index = in_offset + tl.arange(0, BLOCK_SIZE) + out_index = out_offset + tl.arange(0, BLOCK_SIZE) + xmask = in_index < n_elements + + tmp0 = tl.load(in_ptr0 + (in_index), xmask) + tl.atomic_or(out_ptr0 + (out_index), tmp0, xmask) + + +@triton.jit +def atomic_or_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, + DIM0: tl.constexpr, DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): + sub_idx = tl.program_id(1) + base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE + base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE + offsets_src = tl.arange(0, BLOCK_SIZE) + base_src + offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst + mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 + tmp = tl.load(x_ptr + offsets_src, mask) + tl.atomic_or(out_ptr + offsets_dst, tmp, mask) + + +@triton.jit +def atomic_or_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + x = tl.load(x_ptr) # x is scalar or 1D, no mask needed + + # Compute y indices + y_offset = pid * BLOCK_SIZE + y_indices = y_offset + tl.arange(0, BLOCK_SIZE) + y_mask = y_indices < n_elements + + y_value = tl.load(y_ptr + y_indices, y_mask) + # Atomic or: y |= x (broadcasted) + tl.atomic_or(out_ptr + y_indices, y_value, mask=y_mask) + tl.atomic_or(out_ptr + y_indices, x, mask=y_mask) + + +# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) +test_cases = [ + ((1, 1, 1, 1), (1, 1, 1, 4), 4), + ((1, 1, 1, 3), (1, 5, 1, 3), 5), + ((3,), (2, 3, 3, 3, 3), 81), + ((3,), (2, 3, 3, 3), 27), + ((3,), (2, 3, 3), 9), + ((3,), (2, 3), 3), +] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] * 2 + atomic_or[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + aligned_size = grid_size * BLOCK_SIZE + x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() + x_concat[0:n_elements] = x[0:n_elements] + x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] + atomic_or[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) + + expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)] + torch.testing.assert_close(y, expected) + + +# 3d +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or_3d(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_or[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) + + expected = y_temp | x_temp[0:shape[0]] | x_temp[shape[0]:(shape[0] * 2)] + torch.testing.assert_close(y, expected) + + +@triton.jit +def atomic_or_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_or(out_ptr0 + offsets, tmp0) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_or_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() + + x1_ref = x1 | x0_value + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_or_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1, 1, 2), + (10, 1, 15, 1, 7), + (1, 1, 1, 1, 257), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or_5d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + XB, YB, ZB, MB, NB = triton_shape + BLOCK_SIZE = 256 + ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_or_ndim[(2 * XB * YB * ZB * MB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=XB, DIM1=YB, DIM2=ZB, DIM3=MB, DIM4=NB, + ) + + expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1, 1), + (1, 1, 2, 2), + (1, 3, 2, 7), + (1, 3, 2, 651), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or_4d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 4: + triton_shape.append(1) + XB, YB, ZB, MB = triton_shape + + BLOCK_SIZE = 256 + ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_or_ndim[(2 * XB * YB * ZB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=XB, DIM2=YB, DIM3=ZB, DIM4=MB, + ) + + expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1), + (1, 1, 2), + (1, 31, 275), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or_3d_2(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 3: + triton_shape.append(1) + XB, YB, ZB = triton_shape + BLOCK_SIZE = 256 + ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_or_ndim[(2 * XB * YB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=XB, DIM3=YB, DIM4=ZB, + ) + + expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 2), + (1, 1), + (257, 1), + (257, 2), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or_2d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB, YB = triton_shape + BLOCK_SIZE = 256 + ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_or_ndim[(2 * XB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=XB, DIM4=YB, + ) + + expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', [(1,), (9,), (256,), (257,), (65535,), (65536,)]) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_or_1d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB = triton_shape[0] + BLOCK_SIZE = 256 + ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_or_ndim[(2, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=1, DIM4=XB, + ) + + expected = out_temp | x_temp[0:shape[0]] | x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_xchg.py b/third_party/ascend/examples/generalization_cases/test_atomic_xchg.py new file mode 100644 index 000000000..f2f2d5f48 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_xchg.py @@ -0,0 +1,357 @@ +import math +import pytest +import torch +import triton + +import triton.language as tl + +import test_common +from test_common import TestUtils +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'bfloat16', 'int64', 'bool'}] + + +@triton.jit +def atomic_xchg(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): + in_offset = tl.program_id(0) * BLOCK_SIZE + out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE + in_index = in_offset + tl.arange(0, BLOCK_SIZE) + out_index = out_offset + tl.arange(0, BLOCK_SIZE) + xmask = in_index < n_elements + + tmp0 = tl.load(in_ptr0 + (in_index), xmask) + tl.atomic_xchg(out_ptr0 + (out_index), tmp0, xmask) + + +@triton.jit +def atomic_xchg_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, + DIM0: tl.constexpr, DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): + sub_idx = tl.program_id(1) + base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE + base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE + offsets_src = tl.arange(0, BLOCK_SIZE) + base_src + offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst + mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 + tmp = tl.load(x_ptr + offsets_src, mask) + tl.atomic_xchg(out_ptr + offsets_dst, tmp, mask) + + +@triton.jit +def atomic_xchg_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + x = tl.load(x_ptr) # x is scalar or 1D, no mask needed + + # Compute y indices + y_offset = pid * BLOCK_SIZE + y_indices = y_offset + 
tl.arange(0, BLOCK_SIZE) + y_mask = y_indices < n_elements + + y_value = tl.load(y_ptr + y_indices, y_mask) + # Atomic or: y |= x (broadcasted) + tl.atomic_xchg(out_ptr + y_indices, y_value, mask=y_mask) + tl.atomic_xchg(out_ptr + y_indices, x, mask=y_mask) + + +# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) +test_cases = [ + ((1, 1, 1, 1), (1, 1, 1, 4), 4), + ((1, 1, 1, 3), (1, 5, 1, 3), 5), + ((3,), (2, 3, 3, 3, 3), 81), + ((3,), (2, 3, 3, 3), 27), + ((3,), (2, 3, 3), 9), + ((3,), (2, 3), 3), +] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] * 2 + atomic_xchg[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + aligned_size = grid_size * BLOCK_SIZE + x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() + x_concat[0:n_elements] = x[0:n_elements] + x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] + atomic_xchg[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) + + expected = x_temp[shape[0]:(shape[0] * 2)].expand(y_temp.shape) + torch.testing.assert_close(y, expected) + + +# 3d +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg_3d(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_xchg[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) + + expected = x_temp[shape[0]:(shape[0] * 2)].expand(y_temp.shape) + torch.testing.assert_close(y, expected) + + +@triton.jit +def atomic_xchg_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_xchg(out_ptr0 + offsets, tmp0) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_xchg_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' 
+ dtype)).npu() + + x1_ref = x0 + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_xchg_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@pytest.mark.parametrize('shaape', + [ + (1, 1, 1, 1, 2), + (10, 1, 15, 1, 7), + (1, 1, 1, 1, 257), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg_5d(x_dtype_str, shaape): + shape = shaape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + XB, YB, ZB, MB, NB = triton_shape + BLOCK_SIZE = 256 + ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xchg_ndim[(2 * XB * YB * ZB * MB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=XB, DIM1=YB, DIM2=ZB, DIM3=MB, DIM4=NB, + ) + + expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shaape', + [ + (1, 1, 1, 1), + (1, 1, 2, 2), + (1, 3, 2, 7), + (1, 3, 2, 651), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg_4d(x_dtype_str, shaape): + shape = shaape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 4: + triton_shape.append(1) + XB, YB, ZB, MB = triton_shape + + BLOCK_SIZE = 256 + ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xchg_ndim[(2 * XB * YB * ZB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=XB, DIM2=YB, DIM3=ZB, DIM4=MB, + ) + + expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shaape', + [ + (1, 1, 1), + (1, 1, 2), + (1, 31, 275), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg_3d_2(x_dtype_str, shaape): + shape = shaape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 3: + triton_shape.append(1) + XB, YB, ZB = triton_shape + BLOCK_SIZE = 256 + ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xchg_ndim[(2 * XB * YB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=XB, DIM3=YB, DIM4=ZB, + ) + + expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shaape', + [ + (1, 2), + (1, 1), + (257, 1), + (257, 2), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg_2d(x_dtype_str, shaape): + shape = shaape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB, YB = triton_shape + BLOCK_SIZE = 256 + ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xchg_ndim[(2 * XB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=XB, DIM4=YB, + ) + + expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shaape', [(1,), (9,), (256,), (257,), (65535,), (65536,)]) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xchg_1d(x_dtype_str, shaape): + shape = shaape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB = triton_shape[0] + BLOCK_SIZE = 256 + ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xchg_ndim[(2, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=1, DIM4=XB, + ) + + expected = x_temp[shape[0]:x_shape[0]].expand(out_temp.shape) + torch.testing.assert_close(out, expected) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_xor.py b/third_party/ascend/examples/generalization_cases/test_atomic_xor.py new file mode 100644 index 000000000..9018fea93 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_atomic_xor.py @@ -0,0 +1,360 @@ +import math +import pytest +import torch +import triton + +import triton.language as tl + +import test_common +from test_common import TestUtils +filtered_dtype = [dtype for dtype in TestUtils.full_dtype if dtype not in {'uint32', 'float16', 'float32', 'bfloat16', 'int64', 'bool'}] + + +@triton.jit +def atomic_xor(in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr, BLOCK_NUM: tl.constexpr): + in_offset = tl.program_id(0) * BLOCK_SIZE + out_offset = (tl.program_id(0) % BLOCK_NUM) * BLOCK_SIZE + in_index = in_offset + tl.arange(0, BLOCK_SIZE) + out_index = out_offset + tl.arange(0, BLOCK_SIZE) + xmask = in_index < n_elements + + tmp0 = tl.load(in_ptr0 + (in_index), xmask) + tl.atomic_xor(out_ptr0 + (out_index), tmp0, xmask) + + +@triton.jit +def atomic_xor_ndim(x_ptr, out_ptr, NCORE: tl.constexpr, BLOCK_SIZE: tl.constexpr, + DIM0: tl.constexpr, DIM1: tl.constexpr, DIM2: tl.constexpr, DIM3: tl.constexpr, DIM4: tl.constexpr): + sub_idx = tl.program_id(1) + base_src = tl.program_id(0) * DIM4 + sub_idx * BLOCK_SIZE + base_dst = (tl.program_id(0) % (DIM0 * DIM1 * DIM2 * DIM3)) * DIM4 + sub_idx * BLOCK_SIZE + offsets_src = tl.arange(0, BLOCK_SIZE) + base_src + offsets_dst = tl.arange(0, BLOCK_SIZE) + base_dst + mask = tl.arange(0, BLOCK_SIZE) + sub_idx * BLOCK_SIZE < DIM4 + tmp = tl.load(x_ptr + offsets_src, mask) + tl.atomic_xor(out_ptr + offsets_dst, tmp, mask) + + +@triton.jit +def atomic_xor_broadcast(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + x = tl.load(x_ptr) # x is scalar or 1D, no mask needed + + # Compute y indices + y_offset = pid * BLOCK_SIZE + y_indices = 
y_offset + tl.arange(0, BLOCK_SIZE) + y_mask = y_indices < n_elements + + y_value = tl.load(y_ptr + y_indices, y_mask) + # Atomic or: y |= x (broadcasted) + tl.atomic_xor(out_ptr + y_indices, y_value, mask=y_mask) + tl.atomic_xor(out_ptr + y_indices, x, mask=y_mask) + + +# 定义不同测试场景的参数组合 (x_shape, y_shape, BLOCK_SIZE) +test_cases = [ + ((1, 1, 1, 1), (1, 1, 1, 4), 4), + ((1, 1, 1, 3), (1, 5, 1, 3), 5), + ((3,), (2, 3, 3, 3, 3), 81), + ((3,), (2, 3, 3, 3), 27), + ((3,), (2, 3, 3), 9), + ((3,), (2, 3), 3), +] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d + TestUtils.test_shape1d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + + if len(shape) == 1 and shape[0] == 1: # golden 问题,手动验证 + return + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + if len(shape) == 2: + n_elements = shape[0] * shape[1] * 2 + atomic_xor[shape[0] * 2, 1, 1](x, y, n_elements, BLOCK_SIZE=shape[1], BLOCK_NUM=shape[0]) + elif len(shape) == 1: + n_elements = shape[0] + BLOCK_SIZE = min(1024, shape[0]) # 1024:限制最大线程块大小 + grid_size = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE # 向上取整 + aligned_size = grid_size * BLOCK_SIZE + x_concat = torch.full([aligned_size * 2], 0, dtype=x_dtype).npu() + x_concat[0:n_elements] = x[0:n_elements] + x_concat[aligned_size:(aligned_size + n_elements)] = x[n_elements:(n_elements * 2)] + atomic_xor[grid_size * 2, 1, 1](x_concat, y, aligned_size * 2, BLOCK_SIZE=BLOCK_SIZE, BLOCK_NUM=grid_size) + + expected = y_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:(shape[0] * 2)] + torch.testing.assert_close(y, expected) + + +# 3d +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor_3d(x_dtype_str, shape): + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = torch.full(shape, 0, dtype=x_dtype).npu() + + # 保存副本用于验证 + x_temp = x.clone() + y_temp = y.clone() + + n_elements = shape[0] * shape[1] * shape[2] + atomic_xor[2, 1, 1](x, y, n_elements * 2, BLOCK_SIZE=shape[0] * shape[1] * shape[2], BLOCK_NUM=1) + + expected = y_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:(shape[0] * 2)] + torch.testing.assert_close(y, expected) + + +@triton.jit +def atomic_xor_multi_d(in_ptr0, out_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + tmp0 = tl.load(in_ptr0 + offsets) + tl.atomic_xor(out_ptr0 + offsets, tmp0) + + +# multi_d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 8, 4), + (8, 4, 2, 4), + (2, 8, 2, 2), + (2, 4, 8, 4, 2), + (8, 4, 2, 4, 4), + (2, 8, 2, 2, 2), +]) +@pytest.mark.parametrize('dtype', filtered_dtype) +def test_atomic_xor_4d_5d(dtype, shape): + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full(shape, 2, dtype=eval('torch.' + dtype)).npu() + + x1_ref = x1 ^ x0_value + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + atomic_xor_multi_d[(1, )](x0, x1, *triton_shape) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1, 1, 2), + (10, 1, 15, 1, 7), + (1, 1, 1, 1, 257), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor_5d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + XB, YB, ZB, MB, NB = triton_shape + BLOCK_SIZE = 256 + ncore = (NB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xor_ndim[(2 * XB * YB * ZB * MB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=XB, DIM1=YB, DIM2=ZB, DIM3=MB, DIM4=NB, + ) + + expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1, 1), + (1, 1, 2, 2), + (1, 3, 2, 7), + (1, 3, 2, 651), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor_4d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 4: + triton_shape.append(1) + XB, YB, ZB, MB = triton_shape + + BLOCK_SIZE = 256 + ncore = (MB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xor_ndim[(2 * XB * YB * ZB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=XB, DIM2=YB, DIM3=ZB, DIM4=MB, + ) + + expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 1, 1), + (1, 1, 2), + (1, 31, 275), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor_3d_2(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 3: + triton_shape.append(1) + XB, YB, ZB = triton_shape + BLOCK_SIZE = 256 + ncore = (ZB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xor_ndim[(2 * XB * YB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=XB, DIM3=YB, DIM4=ZB, + ) + + expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', + [ + (1, 2), + (1, 1), + (257, 1), + (257, 2), + ] + ) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor_2d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB, YB = triton_shape + BLOCK_SIZE = 256 + ncore = (YB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xor_ndim[(2 * XB, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=XB, DIM4=YB, + ) + + expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize('shape', [(1,), (9,), (256,), (257,), (65535,), (65536,)]) +@pytest.mark.parametrize('x_dtype_str', filtered_dtype) +def test_atomic_xor_1d(x_dtype_str, shape): + shape = shape + + # 获取原始类型 + x_dtype = eval('torch.' 
+ x_dtype_str) + x_shape = list(shape[:]) + x_shape[0] *= 2 + x = torch.randint(low=0, high=100, size=x_shape, dtype=x_dtype).npu() + out = torch.full(shape, 0, dtype=x_dtype).npu() + + x_temp = x.clone() + out_temp = out.clone() + + triton_shape = [*shape] + while len(triton_shape) < 2: + triton_shape.append(1) + XB = triton_shape[0] + BLOCK_SIZE = 256 + ncore = (XB + BLOCK_SIZE - 1) // BLOCK_SIZE + + atomic_xor_ndim[(2, ncore)]( + x_ptr=x, + out_ptr=out, + NCORE=ncore, + BLOCK_SIZE=BLOCK_SIZE, + DIM0=1, DIM1=1, DIM2=1, DIM3=1, DIM4=XB, + ) + + expected = out_temp ^ x_temp[0:shape[0]] ^ x_temp[shape[0]:x_shape[0]] + torch.testing.assert_close(out, expected) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_broadcast.py b/third_party/ascend/examples/generalization_cases/test_broadcast.py new file mode 100644 index 000000000..a5dab9d70 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_broadcast.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def fn_broadcast_1d(output_ptr, x_ptr, XS: tl.constexpr, YS: tl.constexpr): + xidx = tl.arange(0, XS)[None, :] + base = tl.load(x_ptr + xidx) + out = base.broadcast_to((YS, XS)) + oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :] + tl.store(output_ptr + oidx, out) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_npu_1d(shape, dtype): + XS = shape[0] + YS = 4 + + x = test_common.generate_tensor((XS, ), dtype=dtype).npu() + std = torch.broadcast_to(x, (YS, XS)) + output = test_common.generate_tensor((YS, XS), dtype=dtype).npu() + fn_broadcast_1d[1, 1, 1](output, x, XS, YS) + test_common.validate_cmp(dtype, std, output) + + +@triton.jit +def fn_broadcast_2d(output_ptr, x_ptr, NUMEL: tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr): + zoffset = tl.program_id(0) * ZS + zidx = tl.arange(0, ZS)[None, :] + base = tl.load(x_ptr + zoffset + zidx) + out = base.broadcast_to((YS, ZS)) + oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx + tl.store(output_ptr + oidx, out) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_npu_2d(shape, dtype): + XS = shape[0] + ZS = shape[1] + YS = 4 + NUMEL = XS * ZS + + x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu() # randn not support int type + std = torch.broadcast_to(x, (XS, YS, ZS)) + output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu() + fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS) + test_common.validate_cmp(dtype, std, output) + + +@triton.jit +def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(out_ptr0 + odx) + ret = tl.broadcast(x, x1) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def 
test_broadcast_to_dim0(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu() + ans = x0.repeat(L, 1, 1) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(out_ptr0 + odx) + ret = tl.broadcast(x, x1) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim1(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu() + ans = x0.repeat(1, M, 1) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(out_ptr0 + odx) + ret = tl.broadcast(x, x1) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim2(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu() + ans = x0.repeat(1, 1, N) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(out_ptr0 + odx) + ret = tl.broadcast(x, x1) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim01(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu() + ans = x0.repeat(L, M, 1) + output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(out_ptr0 + odx) + ret = tl.broadcast(x, x1) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim02(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu() + ans = x0.repeat(L, 1, N) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + x1 = tl.load(out_ptr0 + odx) + ret = tl.broadcast(x, x1) + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim12(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu() + ans = x0.repeat(1, M, N) + output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def fn_broadcast_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, F_X: tl.constexpr, F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, T_N: tl.constexpr, T_X: tl.constexpr, T_Y: tl.constexpr): + from_offsets = tl.arange(0, F_L) + if F_M is not None: + from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :] + if F_N is not None: + from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :] + if F_X is not None: + from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :] + if F_Y is not None: + from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :] + + to_offsets = tl.arange(0, T_L) + if T_M is not None: + to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :] + if T_N is not None: + to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :] + if T_X is not None: + to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :] + if T_Y is not None: + to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :] + + from_data = tl.load(from_ptr + from_offsets) + to_data = tl.load(to_ptr + to_offsets) + ret_data = tl.broadcast(from_data, to_data) + + tl.store(to_ptr + to_offsets, ret_data) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shapes', [ + [(1, 64, 16, 1), (2, 64, 16, 2)], + [(8, 1, 1, 2), (8, 8, 4, 2)], +]) +@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) +def test_broadcast_to_4d(shapes, dtype): + from_shape, to_shape = shapes + dtype = eval(f"torch.{dtype}") + + x = torch.randint(0, 8, from_shape, dtype=dtype).npu() + y = torch.randint(0, 8, to_shape, dtype=dtype).npu() + expected = x.expand(to_shape) + + grid = (1, ) + triton_from_shape = [*from_shape] + triton_to_shape = [*to_shape] + while len(triton_from_shape) < 5: + triton_from_shape.append(None) + triton_to_shape.append(None) + fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) + assert(torch.equal(y, expected)) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) +@pytest.mark.parametrize('shapes', [ + [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)], + [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)], +]) +def test_broadcast_to_5d(shapes, dtype): + from_shape, to_shape = shapes + dtype = eval(f"torch.{dtype}") + + x = torch.randint(0, 8, from_shape, dtype=dtype).npu() + y = torch.randint(0, 8, to_shape, dtype=dtype).npu() + expected = x.expand(to_shape) + + grid = (1, ) + triton_from_shape = [*from_shape] + triton_to_shape = [*to_shape] + while len(triton_from_shape) < 5: + triton_from_shape.append(None) + triton_to_shape.append(None) + fn_broadcast_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) + assert(torch.equal(y, expected)) diff --git a/third_party/ascend/examples/generalization_cases/test_broadcast_to.py b/third_party/ascend/examples/generalization_cases/test_broadcast_to.py new file mode 100644 index 000000000..f5a0c3b37 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_broadcast_to.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common +from test_common import TestUtils + +@triton.jit +def fn_broadcast_1d(output_ptr, x_ptr, XS: tl.constexpr, YS: tl.constexpr): + xidx = tl.arange(0, XS)[None, :] + base = tl.load(x_ptr + xidx) + out = base.broadcast_to((YS, XS)) + oidx = tl.arange(0, YS)[:, None] * XS + tl.arange(0, XS)[None, :] + tl.store(output_ptr + oidx, out) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_npu_1d(shape, dtype): + XS = shape[0] + YS = 4 + + x = test_common.generate_tensor((XS, ), dtype=dtype).npu() + std = torch.broadcast_to(x, (YS, XS)) + output = test_common.generate_tensor((YS, XS), dtype=dtype).npu() + fn_broadcast_1d[1, 1, 1](output, x, XS, YS) + test_common.validate_cmp(dtype, std, output) + + +@triton.jit +def fn_broadcast_2d(output_ptr, x_ptr, NUMEL:tl.constexpr, XS: tl.constexpr, YS: tl.constexpr, ZS: tl.constexpr): + zoffset = tl.program_id(0) * ZS + zidx = tl.arange(0, ZS)[None, :] + base = tl.load(x_ptr + zoffset + zidx) + out = base.broadcast_to((YS, ZS)) + oidx = zoffset * YS + tl.arange(0, YS)[:, None] * ZS + zidx + tl.store(output_ptr + oidx, out) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_npu_2d(shape, dtype): + XS = shape[0] + ZS = shape[1] + YS = 4 + NUMEL = XS * ZS + + x = test_common.generate_tensor((XS, 1, ZS), dtype=dtype).npu() + std = torch.broadcast_to(x, (XS, YS, ZS)) + output = test_common.generate_tensor((XS, YS, ZS), dtype=dtype).npu() + fn_broadcast_2d[XS, 1, 1](output, x, NUMEL, XS, YS, ZS) + test_common.validate_cmp(dtype, std, output) + + +@triton.jit +def triton_broadcast_to_dim0(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = tl.arange(0, 1)[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = x.broadcast_to(L, M, N) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim0(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(1, M, N), dtype=dtype).npu() + ans = x0.repeat(L, 1, 1) + output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_broadcast_to_dim0[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim1(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = x.broadcast_to(L, M, N) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim1(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(L, 1, N), dtype=dtype).npu() + ans = x0.repeat(1, M, 1) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim1[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim2(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * 1 * M + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = x.broadcast_to(L, M, N) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim2(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(L, M, 1), dtype=dtype).npu() + ans = x0.repeat(1, 1, N) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim2[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim01(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = tl.arange(0, 1)[:, None, None] * N * 1 + tl.arange(0, 1)[None, :, None] * N + nblk_idx[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = x.broadcast_to(L, M, N) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim01(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(1, 1, N), dtype=dtype).npu() + ans = x0.repeat(L, M, 1) + output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_broadcast_to_dim01[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim02(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = tl.arange(0, 1)[:, None, None] * M * 1 + mblk_idx[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = x.broadcast_to(L, M, N) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim02(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(1, M, 1), dtype=dtype).npu() + ans = x0.repeat(L, 1, N) + output = torch.zeros((L, M, N), dtype=eval('torch.' + dtype)).npu() + triton_broadcast_to_dim02[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def triton_broadcast_to_dim12(in_ptr0, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * 1 * 1 + tl.arange(0, 1)[None, :, None] * 1 + tl.arange(0, 1)[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = x.broadcast_to(L, M, N) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_broadcast_to_dim12(shape, dtype): + L, M, N = shape + x0 = test_common.generate_tensor(shape=(L, 1, 1), dtype=dtype).npu() + ans = x0.repeat(1, M, N) + output = torch.zeros((L, M, N), dtype=eval('torch.' 
+ dtype)).npu() + triton_broadcast_to_dim12[1, 1, 1](x0, output, L, M, N) + test_common.validate_cmp(dtype, output, ans) + + +@triton.jit +def fn_broadcast_to_multi_d(to_ptr, from_ptr, F_L: tl.constexpr, F_M: tl.constexpr, F_N: tl.constexpr, F_X: tl.constexpr, F_Y: tl.constexpr, T_L: tl.constexpr, T_M: tl.constexpr, T_N: tl.constexpr, T_X: tl.constexpr, T_Y: tl.constexpr): + from_offsets = tl.arange(0, F_L) + if F_M is not None: + from_offsets = from_offsets[:, None] * F_M + tl.arange(0, F_M)[None, :] + if F_N is not None: + from_offsets = from_offsets[:, :, None] * F_N + tl.arange(0, F_N)[None, None, :] + if F_X is not None: + from_offsets = from_offsets[:, :, :, None] * F_X + tl.arange(0, F_X)[None, None, None, :] + if F_Y is not None: + from_offsets = from_offsets[:, :, :, :, None] * F_Y + tl.arange(0, F_Y)[None, None, None, None, :] + + to_offsets = tl.arange(0, T_L) + if T_M is not None: + to_offsets = to_offsets[:, None] * T_M + tl.arange(0, T_M)[None, :] + if T_N is not None: + to_offsets = to_offsets[:, :, None] * T_N + tl.arange(0, T_N)[None, None, :] + if T_X is not None: + to_offsets = to_offsets[:, :, :, None] * T_X + tl.arange(0, T_X)[None, None, None, :] + if T_Y is not None: + to_offsets = to_offsets[:, :, :, :, None] * T_Y + tl.arange(0, T_Y)[None, None, None, None, :] + + from_data = tl.load(from_ptr + from_offsets) + if F_Y is not None: + ret_data = from_data.broadcast_to((T_L, T_M, T_N, T_X, T_Y)) + elif F_X is not None: + ret_data = from_data.broadcast_to((T_L, T_M, T_N, T_X)) + elif F_N is not None: + ret_data = from_data.broadcast_to((T_L, T_M, T_N)) + elif F_M is not None: + ret_data = from_data.broadcast_to((T_L, T_M)) + else: + ret_data = from_data.broadcast_to((T_L)) + + tl.store(to_ptr + to_offsets, ret_data) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shapes', [ + [(1, 64, 16, 1), (2, 64, 16, 2)], + [(8, 1, 1, 2), (8, 8, 4, 2)], +]) +@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) +def test_broadcast_to_4d(shapes, dtype): + from_shape, to_shape = shapes + dtype = eval(f"torch.{dtype}") + + x = torch.randint(0, 8, from_shape, dtype=dtype).npu() + y = torch.randint(0, 8, to_shape, dtype=dtype).npu() + expected = x.expand(to_shape) + + grid = (1, ) + triton_from_shape = [*from_shape] + triton_to_shape = [*to_shape] + while len(triton_from_shape) < 5: + triton_from_shape.append(None) + triton_to_shape.append(None) + fn_broadcast_to_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) + assert(torch.equal(y, expected)) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('dtype', ["int32", "int64", "float16", "float32", "bfloat16"]) +@pytest.mark.parametrize('shapes', [ + [(1, 4, 2, 1, 4), (2, 4, 2, 8, 4)], + [(3, 1, 2, 1, 4), (3, 4, 2, 8, 4)], +]) +def test_broadcast_to_5d(shapes, dtype): + from_shape, to_shape = shapes + dtype = eval(f"torch.{dtype}") + + x = torch.randint(0, 8, from_shape, dtype=dtype).npu() + y = torch.randint(0, 8, to_shape, dtype=dtype).npu() + expected = x.expand(to_shape) + + grid = (1, ) + triton_from_shape = [*from_shape] + triton_to_shape = [*to_shape] + while len(triton_from_shape) < 5: + triton_from_shape.append(None) + triton_to_shape.append(None) + fn_broadcast_to_multi_d[grid](y, x, *triton_from_shape, *triton_to_shape) + assert(torch.equal(y, expected)) diff --git a/third_party/ascend/examples/generalization_cases/test_cast.py b/third_party/ascend/examples/generalization_cases/test_cast.py new file mode 100644 index 000000000..90c02e80b --- /dev/null +++ 
b/third_party/ascend/examples/generalization_cases/test_cast.py @@ -0,0 +1,378 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import math +import random +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size + +@triton.jit +def cast_to_bool(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.int1) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_i8(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.int8) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_i16(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.int16) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_i32(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.int32) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_i64(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X 
= tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.int64) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_fp32(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.float32) + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_fp16(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.float16) + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_bf16(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.bfloat16) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_uint32(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.uint32) + tl.store(output_ptr + idx, ret) + +@triton.jit +def cast_to_int64(output_ptr, x_ptr, x_stride, y_stride, z_stride, + DIM: tl.constexpr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + if DIM == 1: + xidx = tl.arange(0, XB) + idx = xidx * x_stride + elif DIM == 2: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + idx = xidx[:, None] * x_stride + yidx[None, :] * y_stride + elif DIM == 3: + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * x_stride + yidx[None, :, None] * y_stride + zidx[None, None, :] * z_stride + + X = tl.load(x_ptr + idx) + ret = tl.cast(X, dtype=tl.int64) + tl.store(output_ptr + idx, ret) + +triton_func_map = { + "bool": cast_to_bool, + "int8": cast_to_i8, + "int16": cast_to_i16, + "int32": cast_to_i32, + "float16": cast_to_fp16, + "bfloat16": cast_to_bf16, + "float32": 
cast_to_fp32, + "uint32": cast_to_uint32, + "int64": cast_to_int64 +} + +def structParam(x0): + dim = x0.dim() + stride0, stride1, stride2 = 0, 0, 0 + shape0, shape1, shape2 = 0, 0, 0 + if dim >= 1: + stride0 = x0.stride(0) + shape0 = x0.shape[0] + if dim >= 2: + stride1 = x0.stride(1) + shape1 = x0.shape[1] + if dim == 3: + stride2 = x0.stride(2) + shape2 = x0.shape[2] + return dim, stride0, stride1, stride2, shape0, shape1, shape2 + + +@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) +@pytest.mark.parametrize('srcDtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dstDtype', TestUtils.full_dtype) +def test_cast(srcDtype, dstDtype, shape): + if srcDtype == dstDtype: + return + srcBytes = get_dtype_size(srcDtype) + dstBytes = get_dtype_size(dstDtype) + dtype_size = max(srcBytes, dstBytes) + if dstDtype == 'int8': + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): + print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") + return + elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): + print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") + return + + x0 = test_common.generate_tensor(shape, srcDtype) + torch_res = x0.to(eval("torch." + dstDtype)) + x0 = x0.npu() + triton_func = triton_func_map.get(dstDtype, None) + assert triton_func is not None, f"triton_func not Found, srcDtype:{srcDtype}, dstDtype:{dstDtype}" + triton_res = torch.empty(shape, dtype=eval("torch." + dstDtype)).npu() + dim, stride0, stride1, stride2, XB, YB, ZB = structParam(x0) + assert 0 <= dim <= 3, f"dim out of range [0, 3], dim:{dim}" + triton_func[1, 1, 1](triton_res, x0, stride0, stride1, stride2, dim, XB, YB, ZB) + test_common.validate_cmp(dstDtype, triton_res, torch_res) + + +@triton.jit +def cast_to_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + dtype = output_ptr.type.element_ty + + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + X = tl.load(x_ptr + offsets) + ret = tl.cast(X, dtype=dtype) + + tl.store(output_ptr + offsets, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (6, 2, 4, 2), + (4, 2, 8, 4), + (4, 3, 8, 4), +]) +@pytest.mark.parametrize('srcDtype', + ['int8', 'float16', 'float32'] +) +@pytest.mark.parametrize('dstDtype', + ['int8', 'float16', 'float32'] +) +def test_cast_4d(srcDtype, dstDtype, shape): + if srcDtype == dstDtype: + return + srcBytes = get_dtype_size(srcDtype) + dstBytes = get_dtype_size(dstDtype) + dtype_size = max(srcBytes, dstBytes) + if dstDtype == 'int8': + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): + print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") + return + elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): + print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") + return + + x0 = test_common.generate_tensor(shape, srcDtype) + torch_res = x0.to(eval("torch." + dstDtype)) + x0 = x0.npu() + + triton_res = torch.empty(shape, dtype=eval("torch." 
+ dstDtype)).npu() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + grid = (1, ) + cast_to_multi_d[grid](triton_res, x0, *triton_shape) + test_common.validate_cmp(dstDtype, triton_res, torch_res) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 6, 2, 4, 2), + (2, 4, 2, 8, 4), + (3, 4, 2, 8, 4), +]) +@pytest.mark.parametrize('srcDtype', + ['int8', 'float16', 'float32'] +) +@pytest.mark.parametrize('dstDtype', + ['int8', 'float16', 'float32'] +) +def test_cast_5d(srcDtype, dstDtype, shape): + if srcDtype == dstDtype: + return + srcBytes = get_dtype_size(srcDtype) + dstBytes = get_dtype_size(dstDtype) + dtype_size = max(srcBytes, dstBytes) + if dstDtype == 'int8': + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 100): + print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") + return + elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 12): + print(f"srcDtype:{srcDtype}, dstDtype:{dstDtype}, shape:{shape} mem overflow") + return + + x0 = test_common.generate_tensor(shape, srcDtype) + torch_res = x0.to(eval("torch." + dstDtype)) + x0 = x0.npu() + + triton_res = torch.empty(shape, dtype=eval("torch." + dstDtype)).npu() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + grid = (1, ) + cast_to_multi_d[grid](triton_res, x0, *triton_shape) + test_common.validate_cmp(dstDtype, triton_res, torch_res) + + +if __name__ == "__main__": + for shape in [(3, ), (3, 3), (3, 3, 3)]: + for srcDtype in ['int8', 'float32', 'bool']: + for dstDtype in ['int8', 'float32', 'bool']: + test_cast(srcDtype, dstDtype, shape) diff --git a/third_party/ascend/examples/generalization_cases/test_cdiv.py b/third_party/ascend/examples/generalization_cases/test_cdiv.py new file mode 100644 index 000000000..47a9de7d5 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_cdiv.py @@ -0,0 +1,149 @@ +import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math +import logging + + +def torch_cdiv(x0, x1, dtype): + return (x0 + x1 - 1) // x1 + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.cdiv(X, Y) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_cdiv_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < 
SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = tl.cdiv(x_val, y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['int8', 'int16', 'int32', 'int64']) +def test_case2(dtype, shape): + # Generate test data; int8 cdiv overflow behaves differently in Triton than in torch on CPU + x = (test_common.generate_tensor(shape, dtype) // 2).abs().npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + y = (y.abs() // 2 + 1) + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_cdiv(x.cpu(), y.cpu(), eval('torch.' + dtype)) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if dtype == 'int8': + if x.numel() * x.element_size() >= 512: + grid = (1, 1, ZB) + ZB = 1 + else: + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) +def test_cdiv_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = (test_common.generate_tensor(shape, dtype) // 2).abs().npu() + y = test_common.generate_tensor(shape, dtype).npu() + y = (y.abs() // 2 + 1) + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_cdiv(x.cpu(), y.cpu(), eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_cdiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_ceil.py b/third_party/ascend/examples/generalization_cases/test_ceil.py new file mode 100644 index 000000000..9f1a51725 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_ceil.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common +from test_common import TestUtils +import logging +import math + + +def torch_ceil(x0): + res = torch.ceil(x0) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.ceil(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_ceil_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.ceil(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # Generate test data + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.'
+ dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_ceil(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_ceil_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_ceil(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_ceil_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_common.py b/third_party/ascend/examples/generalization_cases/test_common.py new file mode 100644 index 000000000..bef24ec3c --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_common.py @@ -0,0 +1,264 @@ +import os +import re +import torch +import torch_npu +import math +import logging +from typing import AnyStr +import pytest +import functools + +log_level = os.getenv("LOG_LEVEL", "WARN").upper() +level_mapping = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARN": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL +} + +logging.basicConfig( + level=level_mapping.get(log_level, logging.WARNING), + format="[%(asctime)s][%(levelname)s] %(message)s" +) + +bisheng_not_support_dtypes = { + 'abs': [], + 'eq': [], + 'ne': [], + 'flip':['int64', 'bfloat16'], + 'load_store': ['int64'], + 'permute2d': ['int64'], + 'permute3d': ['int64'], + 'trans2d': ['int64'], + 'trans3d': ['int64'], + 'matmul': ['int16', 'int32', 'uint32', 'int64', 'bool'] +} + +tritonascend_not_support_dtypes = { + 'abs': ['bool'], + 'eq': ['bool'], + 'ne': ['bool'], + 'flip':['bool'], + 'load_store': ['bool'], + 'permute2d': ['bool'], + 'permute3d': ['bool'], + 'trans2d': ['bool'], + 'trans3d': ['bool'], +} + +def avoid_not_support(op: AnyStr): + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(shape, dtype, *args, **kwargs): + if dtype in bisheng_not_support_dtypes.get(op, []): + logging.warning(f'skipped: bisheng does not support dtype {dtype}') + return + if dtype in tritonascend_not_support_dtypes.get(op, []): + logging.warning(f'skipped: triton-ascend does not support dtype {dtype}') + return + return test_func(shape, dtype, *args, **kwargs) + return wrapper + return decorator + + +def get_shape1d(in_shape1d): + result = [] + for i in in_shape1d: + v = tuple((i,)) + result.append(v) + return result + +def get_shape2d(in_shape1d, custom_shape): + result = [] + for a in in_shape1d: + for b in custom_shape: + t1 = tuple((a, b)) + t2 = tuple((b, a)) + if t1 not in result: + result.append(t1) + if t2 not in result:
+ result.append(t2) + return result + +def get_shape3d(): + return [(1,22,39),(27,1,39),(27,22,1),(23,1,1),(1,23,1),(1,1,23),(37,5,3),(2,29,4),(7,31,7),(3,5,8),(7,17,15),(23,5,16),(23,5,31),(7,11,32),(7,11,33),(2,3,255),(3,3,256),(3,2,257)] + +def get_shape1_2_3d(in_shape1d, custom_shape): + return get_shape1d(in_shape1d) + get_shape2d(in_shape1d, custom_shape) + get_shape3d() + +class TestUtils: + in_shape1d = [1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 37, 741] + custom_shape = [3, 13, 32, 256] + batch = [1, 2, 3, 4, 5, 8] + test_shape1d = get_shape1d(in_shape1d) + test_shape2d = get_shape2d(in_shape1d, custom_shape) + test_shape3d = [(1,22,39), (27,1,39), (27,22,1), (1,1,23), (23,1,1), (1,23,1), + (37,5,3), (2,29,4), (7,31,7), (3,5,8), (7,17,15), (25,5,16), + (23,5,31), (7,11,32), (7,11,33), (2,3,255), (3,3,256), (3,2,257),] + test_shape4d = [(8, 4, 8, 8), (1, 11, 16, 2)] + test_shape5d = [(2, 3, 4, 5, 6), (1, 3, 4, 5, 6), (3, 6, 2, 4, 4)] + test_shape6d = [(2, 3, 5, 6, 3, 2)] + test_shape7d = [(1, 2, 3, 4, 3, 2, 2)] + test_shape_ub_overflow = [(10, 50, 1000)] + test_shape8d = [(1, 2, 3, 2, 5, 3, 7, 2), (1, 3, 2, 5, 6, 7, 2, 1), (2, 3, 7, 3, 2, 3, 2, 3)] + full_shape_4_8d = test_shape4d + test_shape5d + test_shape6d + test_shape7d + test_shape8d + + full_shape = test_shape1d + test_shape2d + test_shape3d + test_shape1_2_3d = full_shape + full_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32', 'bool'] + ub_size = 98304 * 2 + dtype_list = full_dtype + +def get_dtype_size(dtype): + torch_dtype = eval('torch.' + dtype) + bits = 0 + if torch_dtype == torch.bool: + bits = 8 + elif torch.is_floating_point(torch.tensor(0, dtype=torch_dtype)): + bits = torch.finfo(torch_dtype).bits + else: + bits = torch.iinfo(torch_dtype).bits + return bits//8 + +def check_ub_mem_overflow(dtype, shape): + bytes = get_dtype_size(dtype) + if bytes * math.prod(shape) > TestUtils.ub_size: + logging.warning(f'dtype:{dtype} shape:{shape} mem overflow') + return True + return False + + +def generate_tensor(shape, dtype): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + return torch.randn(size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'uint32': + return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape).bool() + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + +def generate_tensor_int_withSigns(shape, dtype): + if dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + return torch.randint(low=-32768, high=32767, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=-128, high=127, size=shape, dtype=eval('torch.' 
+ dtype)) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape).bool() + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + +def get_triton_sig_typename(dtype): + if dtype == 'float32': + tyname = "*fp32" + elif dtype == 'int32': + tyname = "*i32" + elif dtype == 'int64': + tyname = "*i64" + elif dtype == 'float16': + tyname = "*fp16" + elif dtype == 'int16': + tyname = "*i16" + elif dtype == 'int8': + tyname = "*i8" + elif dtype == 'bool': + tyname = "*i1" + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + return tyname + +# Relative error: abs(x_ref - x_cal) / abs(x_ref) +# Absolute error: abs(x_ref - x_cal) + +# Calculation-type operators require a different error range. +# This stricter verification is not satisfied yet; it is kept here for future use. +def validate_cal(dtype, y_cal, y_ref): + if dtype == 'float16': + if torch.mean(y_ref) < 0.001: + assert (torch.abs(y_cal - y_ref) < 0.001).all(), "|y_cal - y_ref| < 0.001 is required !" + else: + diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 + # all true + assert diff.all(), "Relative error must be less than 0.001 !" + elif dtype == 'float32': + if torch.mean(y_ref) < 0.0001: + assert (torch.abs(y_cal - y_ref) < 0.0001).all(), "|y_cal - y_ref| < 0.0001 is required !" + else: + diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.0001 + assert diff.all(), "Relative error must be less than 0.0001 !" + elif dtype == 'bfloat16': + diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 + assert diff.all(), "Relative error must be less than 0.001 !" + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + assert torch.equal(y_cal, y_ref) + elif dtype == 'bool': + assert torch.equal(y_cal, y_ref) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + +# moving and comparison ops require no precision error +def validate_cmp(dtype, y_cal, y_ref): + y_cal = y_cal.npu() + y_ref = y_ref.npu() + if dtype == 'float16': + torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) + elif dtype == 'bfloat16': + torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, + equal_nan=True) + elif dtype == 'float32': + torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8' or dtype == 'uint32': + assert torch.equal(y_cal, y_ref) + elif dtype == 'bool': + assert torch.equal(y_cal.cpu(), y_ref.cpu()) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + +def validate_cmp_with_expection(dtype, y_cal, y_ref, expect): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + if expect: + assert torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) + else: + assert not torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': + if expect: + assert torch.equal(y_cal, y_ref) + else: + assert not torch.equal(y_cal, y_ref) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + +def raises_with_match(expected_exception, match_pattern): + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + with pytest.raises(expected_exception, match=match_pattern): + return test_func(*args, **kwargs) + return wrapper + return decorator + +def 
capture_output(expected_output): + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + capsys = kwargs.pop('capsys', None) + if capsys is None: + try: + capsys = pytest.fixture(capsys)() + except: + raise RuntimeError("This decorator requires pytest's capsys fixture") + test_func(capsys, *args, **kwargs) + captured = capsys.readouterr() + # pybind11::scoped_ostream_redirect captures std::cout with \x00 inserted + # For now there is no known way to eliminate the \x00 characters on the C++ side. + cleaned = re.sub(r"\x00", "", captured.out) + assert expected_output in cleaned + return wrapper + return decorator diff --git a/third_party/ascend/examples/generalization_cases/test_cos.py b/third_party/ascend/examples/generalization_cases/test_cos.py new file mode 100644 index 000000000..2377ae417 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_cos.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common +from test_common import TestUtils +import logging +import math + + +def torch_cos(x0): + res = torch.cos(x0) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.cos(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_cos_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.cos(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # Generate test data + x =
test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_cos(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_cos_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_cos(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_cos_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_count_dim0.py b/third_party/ascend/examples/generalization_cases/test_count_dim0.py new file mode 100644 index 000000000..0b37f2987 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_count_dim0.py @@ -0,0 +1,121 @@ + # -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils + +def standard_count(x0, cmp_val, dim, dtype): + res = (x0 == cmp_val).sum(dim=dim) + return res + +def standard_count_gt(x0, cmp_val, dim, dtype): + res = (x0 > cmp_val).sum(dim=dim) + return res + +def standard_count_lt(x0, cmp_val, dim, dtype): + res = (x0 < cmp_val).sum(dim=dim) + return res + +@triton.jit +def count(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, N) + tl.program_id(2) * N + mmask = mblk_idx < MNUMEL + nmask = nblk_idx < NNUMEL + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=0) + tmp1 = (x == cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + nblk_idx, ret, mask = nmask) + +@triton.jit +def count_gt(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, N) + tl.program_id(2) * N + mmask = mblk_idx < MNUMEL + nmask = nblk_idx < NNUMEL + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=0) + tmp1 = (x > cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + nblk_idx, ret, mask = nmask) + +@triton.jit +def count_lt(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, N) + tl.program_id(2) * N + mmask = mblk_idx < MNUMEL + nmask = nblk_idx < NNUMEL + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=0) + tmp1 = (x < cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + nblk_idx, ret, mask = nmask) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['int8']) +def test_count_dim0_common(shape, dtype): + rblock = shape[1] + xblock = shape[0] + x0 = test_common.generate_tensor(shape, dtype).npu() + + if dtype == 'int8': + cmp_val = 8 + else: + cmp_val = 0.5 + + ans = standard_count(x0, cmp_val, 0, dtype) + + output = torch.zeros((shape[1],), dtype = torch.float32).npu() + count[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) + + test_common.validate_cmp("float32", output, ans.to(torch.float32)) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) +def test_count_gt_dim0_common(shape, dtype): + rblock = shape[1] + xblock = shape[0] + x0 = test_common.generate_tensor(shape, dtype).npu() + + if dtype == 'int8': + cmp_val = 8 + else: + cmp_val = 0.5 + + ans = standard_count_gt(x0, cmp_val,0, dtype) + + output = torch.zeros((shape[1],), dtype = torch.float32).npu() + count_gt[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) + + test_common.validate_cmp("float32", output, ans.to(torch.float32)) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) +def test_count_lt_dim0_common(shape, dtype): + rblock = shape[1] + 
xblock = shape[0] + x0 = test_common.generate_tensor(shape, dtype).npu() + + if dtype == 'int8': + cmp_val = 8 + else: + cmp_val = 0.5 + + ans = standard_count_lt(x0, cmp_val,0, dtype) + + output = torch.zeros((shape[1],), dtype = torch.float32).npu() + count_lt[1, 1, rblock](x0, output, cmp_val, 0, xblock, 1, xblock, rblock) + + test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git a/third_party/ascend/examples/generalization_cases/test_count_dim1.py b/third_party/ascend/examples/generalization_cases/test_count_dim1.py new file mode 100644 index 000000000..951e15032 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_count_dim1.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils + + +def standard_count(x0, cmp_val, dim, dtype): + res = (x0 == cmp_val).sum(dim=dim) + return res + + +def standard_count_gt(x0, cmp_val, dim, dtype): + res = (x0 > cmp_val).sum(dim=dim) + return res + + +def standard_count_lt(x0, cmp_val, dim, dtype): + res = (x0 < cmp_val).sum(dim=dim) + return res + + +@triton.jit +def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, M) + tl.program_id(1) * M + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < MNUMEL + nmask = nblk_idx < NNUMEL + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=0) + tmp1 = (x == cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) + + +@triton.jit +def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, M) + tl.program_id(1) * M + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < MNUMEL + nmask = nblk_idx < NNUMEL + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=0) + tmp1 = (x > cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) + + +@triton.jit +def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, M) + tl.program_id(1) * M + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < MNUMEL + nmask = nblk_idx < NNUMEL + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * NNUMEL + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=0) + tmp1 = (x < cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['int8']) +def test_count_dim1_common(shape, dtype): + rblock = shape[1] + xblock = shape[0] + x0 = test_common.generate_tensor(shape, dtype).npu() + + if dtype == 'int8': + cmp_val = 8 + else: + cmp_val = 0.5 + + ans = standard_count(x0, cmp_val,1, dtype) + + output = torch.zeros((shape[0],), dtype = torch.float32).npu() + count[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) + + test_common.validate_cmp("float32", output, 
ans.to(torch.float32)) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) +def test_count_gt_dim1_common(shape, dtype): + rblock = shape[1] + xblock = shape[0] + x0 = test_common.generate_tensor(shape, dtype).npu() + + if dtype == torch.int8: + cmp_val = 8 + else: + cmp_val = 0.5 + + ans = standard_count_gt(x0, cmp_val,1, dtype) + + output = torch.zeros((shape[0],), dtype = torch.float32).npu() + count_gt[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) + + test_common.validate_cmp("float32", output, ans.to(torch.float32)) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int8']) +def test_count_lt_dim1_common(shape, dtype): + rblock = shape[1] + xblock = shape[0] + x0 = test_common.generate_tensor(shape, dtype).npu() + + if dtype == torch.int8: + cmp_val = 8 + else: + cmp_val = 0.5 + + ans = standard_count_lt(x0, cmp_val,1, dtype) + + output = torch.zeros((shape[0],), dtype = torch.float32).npu() + count_lt[1, xblock, 1](x0, output, cmp_val, 1, 1, rblock, xblock, rblock) + + test_common.validate_cmp("float32", output, ans.to(torch.float32)) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_cumprod.py b/third_party/ascend/examples/generalization_cases/test_cumprod.py new file mode 100644 index 000000000..6cdfe6859 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_cumprod.py @@ -0,0 +1,255 @@ +import math +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +from triton.runtime.libentry import libentry + +from test_common import TestUtils, validate_cmp, get_dtype_size + + +def torch_func(x, dim, reverse): + is_bf16 = x.dtype == torch.bfloat16 + if is_bf16: + x = x.to(torch.float32) + if reverse: + x = torch.flip(x, [dim]) + res = torch.cumprod(x, dim=dim) + if is_bf16: + res = res.to(torch.bfloat16) + return res + + +@libentry() +@triton.jit +def triton_kernel_1d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + XBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + idx = tl.arange(0, XBLOCK) + x = tl.load(in_ptr0 + idx) + ret = tl.cumprod(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_2d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx = idx_x[:, None] * numel_r + idx_r[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.cumprod(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_3d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + numel_z: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + tl.static_assert( + numel_z == 
ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx_z = tl.arange(0, ZBLOCK) + idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.cumprod(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_4d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + XB: tl.constexpr, + YB: tl.constexpr, + ZB: tl.constexpr, + MB: tl.constexpr, +): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + + zidx[None, None, :, None] * MB + midx[None, None, None, :]) + x = tl.load(in_ptr0 + idx) + ret = tl.cumprod(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_5d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + XB: tl.constexpr, + YB: tl.constexpr, + ZB: tl.constexpr, + MB: tl.constexpr, + NB: tl.constexpr, +): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + + nidx[None, None, None, None, :]) + x = tl.load(in_ptr0 + idx) + ret = tl.cumprod(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +def convert_cumprod_dtype(x: torch.Tensor) -> torch.Tensor: + """ + Return the input tensor cast according to the cumprod dtype conversion rules. + """ + dtype_map = { + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bfloat16: torch.bfloat16, + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.bool: torch.int64, + } + + target_dtype = dtype_map.get(x.dtype, None) + if target_dtype is None: + raise ValueError(f"Unsupported input dtype for cumprod conversion: {x.dtype}") + + return x.to(target_dtype) + + +def triton_func(x, dim, reverse): + x = convert_cumprod_dtype(x) + + res = torch.empty_like(x) + shape = x.size() + if len(shape) == 1: + if dim >= 1: + pytest.skip("dim >= 1 for 1D tensor, skipping.") + triton_kernel_1d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[0] + ) + elif len(shape) == 2: + if dim >= 2: + pytest.skip("dim >= 2 for 2D tensor, skipping.") + triton_kernel_2d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1] + ) + elif len(shape) == 3: + if dim >= 3: + pytest.skip("dim >= 3 for 3D tensor, skipping.") + triton_kernel_3d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], x.shape[2] + ) + elif len(shape) == 4: + if dim >= 4: + pytest.skip("dim >= 4 for 4D tensor, skipping.") + triton_kernel_4d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3] + ) + elif len(shape) == 5: + if dim >= 5: + pytest.skip("dim >= 5 for 5D tensor, skipping.") + triton_kernel_5d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4] + ) + else: + pytest.skip(f"Unsupported tensor dimension: {len(shape)}") + + return res + + +def cumprod_generate_tensor(shape, dtype): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + return torch.rand(size=shape, dtype=eval('torch.' 
+ dtype)) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + return torch.randint(low=1, high=5, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=1, high=5, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape).bool() + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def should_skip_due_to_mem(dtype, shape): + dtype_size = get_dtype_size(dtype) + total_mem = dtype_size * math.prod(shape) + + if dtype in ('int8', 'bool'): + threshold = TestUtils.ub_size / 13 + else: + threshold = TestUtils.ub_size / 6 + + if total_mem >= threshold: + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + + +# reverse=True not support; +@pytest.mark.parametrize("dtype", TestUtils.full_dtype) +@pytest.mark.parametrize("shape", TestUtils.full_shape) +@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) +@pytest.mark.parametrize("reverse", [False]) +def test_cumprod(dtype, shape, dim, reverse): + should_skip_due_to_mem(dtype, shape) + + x = cumprod_generate_tensor(shape=shape, dtype=dtype) + x_npu = x.npu() + + triton_res = triton_func(x_npu, dim, reverse) + + x_gold = x + cpu_res = torch_func(x_gold, dim, reverse) + + validate_cmp(dtype, triton_res, cpu_res) diff --git a/third_party/ascend/examples/generalization_cases/test_cumsum.py b/third_party/ascend/examples/generalization_cases/test_cumsum.py new file mode 100644 index 000000000..2ceafef6d --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_cumsum.py @@ -0,0 +1,254 @@ +import math +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +from triton.runtime.libentry import libentry + +import acc_util +import test_common +from test_common import TestUtils, get_dtype_size + + +def torch_func(x, dim, reverse): + if reverse: + x = torch.flip(x, [dim]) + res = torch.cumsum(x, dim=dim) + return res + + +@libentry() +@triton.jit +def triton_kernel_1d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + XBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + idx = tl.arange(0, XBLOCK) + x = tl.load(in_ptr0 + idx) + ret = tl.cumsum(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_2d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx = idx_x[:, None] * numel_r + idx_r[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.cumsum(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_3d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + numel_z: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + tl.static_assert( + numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel" + ) + idx_x = 
tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx_z = tl.arange(0, ZBLOCK) + idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.cumsum(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_4d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + XB: tl.constexpr, + YB: tl.constexpr, + ZB: tl.constexpr, + MB: tl.constexpr, +): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + idx = (xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + + zidx[None, None, :, None] * MB + midx[None, None, None, :]) + x = tl.load(in_ptr0 + idx) + ret = tl.cumsum(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +@libentry() +@triton.jit +def triton_kernel_5d( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + XB: tl.constexpr, + YB: tl.constexpr, + ZB: tl.constexpr, + MB: tl.constexpr, + NB: tl.constexpr, +): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + idx = (xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + + nidx[None, None, None, None, :]) + x = tl.load(in_ptr0 + idx) + ret = tl.cumsum(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +def convert_cumsum_dtype(x: torch.Tensor) -> torch.Tensor: + """ + Return the input tensor cast according to the cumsum dtype conversion rules. + """ + dtype_map = { + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bfloat16: torch.bfloat16, + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.bool: torch.int64, + } + + target_dtype = dtype_map.get(x.dtype, None) + if target_dtype is None: + raise ValueError(f"Unsupported input dtype for cumsum conversion: {x.dtype}") + + return x.to(target_dtype) + + +def triton_func(x, dim, reverse): + x = convert_cumsum_dtype(x) + + res = torch.empty_like(x) + shape = x.size() + if len(shape) == 1: + if dim >= 1: + pytest.skip("dim >= 1 for 1D tensor, skipping.") + triton_kernel_1d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[0] + ) + elif len(shape) == 2: + if dim >= 2: + pytest.skip("dim >= 2 for 2D tensor, skipping.") + triton_kernel_2d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1] + ) + elif len(shape) == 3: + if dim >= 3: + pytest.skip("dim >= 3 for 3D tensor, skipping.") + triton_kernel_3d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], x.shape[2] + ) + elif len(shape) == 4: + if dim >= 4: + pytest.skip("dim >= 4 for 4D tensor, skipping.") + triton_kernel_4d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3] + ) + elif len(shape) == 5: + if dim >= 5: + pytest.skip("dim >= 5 for 5D tensor, skipping.") + triton_kernel_5d[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4] + ) + else: + pytest.skip(f"Unsupported tensor dimension: {len(shape)}") + + return res + + +def cumsum_generate_tensor(shape, dtype): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + return torch.rand(size=shape, dtype=eval('torch.' 
+ dtype)) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape).bool() + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def should_skip_due_to_mem(dtype, shape): + dtype_size = get_dtype_size(dtype) + total_mem = dtype_size * math.prod(shape) + + if dtype in ('int8', 'bool'): + threshold = TestUtils.ub_size / 13 + else: + threshold = TestUtils.ub_size / 6 + + if total_mem >= threshold: + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + + +# reverse=True not support; + + +@pytest.mark.parametrize("dtype", TestUtils.full_dtype) +@pytest.mark.parametrize("shape", TestUtils.full_shape) +@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4]) +@pytest.mark.parametrize("reverse", [False]) +def test_cumsum(dtype, shape, dim, reverse): + should_skip_due_to_mem(dtype, shape) + + x = cumsum_generate_tensor(shape=shape, dtype=dtype) + x_npu = x.npu() + + triton_res = triton_func(x_npu, dim, reverse) + + x_gold = x + cpu_res = torch_func(x_gold, dim, reverse) + + test_common.validate_cmp(dtype, triton_res, cpu_res) diff --git a/third_party/ascend/examples/generalization_cases/test_debug_barrier.py b/third_party/ascend/examples/generalization_cases/test_debug_barrier.py new file mode 100644 index 000000000..a0a18ce2b --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_debug_barrier.py @@ -0,0 +1,154 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import logging +import pytest +import test_common +from test_common import TestUtils + + +def torch_invert(x0, ddtype): + if 'float' in str(ddtype): + x0 = x0.to(torch.int32) + y_ref = ~x0 + y_ref = y_ref.to(ddtype) + else: + y_ref = ~x0 + return y_ref + + +@triton.jit +def triton_sub(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X - Y + tl.debug_barrier() + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_invert_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + 
offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = ~x_val + tl.debug_barrier() + tl.store(output_ptr + offsets, ret, mask=masks) + + +test_shape_1d_2d_3d = [(1,), (2,), (1, 1), (3, 13), (1, 1, 1), (4, 3, 8)] +test_shape_4_5d = [(1, 1, 1, 1), (2, 2, 2, 2), (1, 1, 1, 1, 1), (2, 2, 2, 2, 1)] + + +@pytest.mark.parametrize('shape', test_shape_1d_2d_3d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_sub(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x - y + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if dtype == 'int8': + if x.numel() * x.element_size() >= 512: + grid = (1, 1, ZB) + ZB = 1 + else: + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + triton_sub[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', test_shape_1d_2d_3d + test_shape_4_5d) +@pytest.mark.parametrize('dtype', ['bool']) +def test_invert_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_invert(x, eval('torch.' 
+ dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_invert_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_device_print_op.py b/third_party/ascend/examples/generalization_cases/test_device_print_op.py new file mode 100644 index 000000000..5880285ad --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_device_print_op.py @@ -0,0 +1,133 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common + +import os +os.environ["TRITON_DEVICE_PRINT"] = "1" +os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" + +shape = (8,) +XS = 8 +XVALS_INT = [0, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + torch.iinfo(torch.int16).min, + torch.iinfo(torch.int16).max, + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + torch.iinfo(torch.int32).max+1] +XVALS_FP = [0, + torch.finfo(torch.float32).eps, + torch.finfo(torch.float16).eps, + torch.finfo(torch.bfloat16).eps, + torch.finfo(torch.float32).max, + torch.finfo(torch.float16).max, + torch.finfo(torch.bfloat16).max, + 1] + +def torch_func(x0, x1): + res = x0 + x1 + return res + +@triton.jit +def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): + idx = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.device_print("OUTPUT = ", tmp2) + tl.store(out_ptr0 + idx, tmp2) + +def triton_func(x0, x1, XS): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, XS) + return out + +@pytest.mark.skip(reason="waiting for bishengir-compile to support") +@pytest.mark.parametrize('sigtype', ['int64']) +@test_common.capture_output("???") +def test_device_print_int64(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.parametrize('sigtype', ['int32']) +@test_common.capture_output("0,-128,127,-32768,32767,-2147483648,2147483647,-2147483648") +def test_device_print_int32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.parametrize('sigtype', ['int16']) +@test_common.capture_output("0,-128,127,-32768,32767,0,-1,0") +def test_device_print_int16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.parametrize('sigtype', ['int8']) +@test_common.capture_output("0,-128,127,0,-1,0,-1,0") +def test_device_print_int8(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + 
torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.parametrize('sigtype', ['float32']) +@test_common.capture_output("0,1.19209e-07,0.000976562,0.0078125,3.40282e+38,65504,3.38953e+38,1") +def test_device_print_fp32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.parametrize('sigtype', ['float16']) +@test_common.capture_output("0,1.19209e-07,0.000976562,0.0078125,inf,65504,inf,1") +def test_device_print_fp16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.skip(reason="waiting for bishengir-compile to support") +@pytest.mark.parametrize('sigtype', ['bfloat16']) +@test_common.capture_output("???") +def test_device_print_bf16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_div_rn.py b/third_party/ascend/examples/generalization_cases/test_div_rn.py new file mode 100644 index 000000000..b8c6a8923 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_div_rn.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
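+# Generalization tests for tl.div_rn (round-to-nearest floating-point division): each case compares the Triton NPU kernel result against eager torch division.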
+import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common +from test_common import TestUtils +import logging +import math + + +def torch_divRn(x0, x1): + return x0 / x1 + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.div_rn(X, Y) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_div_rn_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = tl.div_rn(x_val, y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # generate test data + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_divRn(x, y) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_div_rn_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + ans = torch_divRn(x, y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_div_rn_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_dot_scaled.py b/third_party/ascend/examples/generalization_cases/test_dot_scaled.py new file mode 100644 index 000000000..e2c94a355 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_dot_scaled.py @@ -0,0 +1,285 @@ +import contextlib +import itertools +import re +import math +import textwrap +import os +import inspect +import pathlib +import test_common +import numpy as np +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +from numpy.random import RandomState +from triton.language.extra import libdevice + + +@triton.jit +def dot_scale_kernel(a_base, stride_a0: tl.constexpr, stride_a1: tl.constexpr, a_scale, b_base, stride_b0: tl.constexpr, + stride_b1: tl.constexpr, b_scale, out, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, + type_b: tl.constexpr, acc_num: tl.constexpr): + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + str_a0: tl.constexpr = stride_a0 + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, + str_a0)[None, :] * stride_a1 + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, + BLOCK_N)[None, :] * stride_b1 + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + if a_scale is not None: + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + a_scale = tl.load(scale_a_ptr) + if b_scale is not None: + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + b_scale = tl.load(scale_b_ptr) + accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32) + if acc_num is not None: + for _ in range(acc_num): + accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, + out_dtype=tl.float32) + + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + 
tl.store(out_ptr, accumulator.to(a.dtype)) + + +def golden_ref(x, scale_x, y, scale_y): + shape_expand_x = x.shape[-1] // scale_x.shape[-1] + if x.dtype == torch.bfloat16: + upscale_x = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int16) + upscale_x = (upscale_x + 127 << 7).view(torch.bfloat16) + else: + scale_fp32 = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int32) + scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) + upscale_x = scale_fp32.to(torch.float16) + upscale_y = None + if scale_y is None: + upscale_y = torch.ones_like(y) + else: + scale_y = scale_y.T + shape_expand_y = y.shape[0] // scale_y.shape[0] + if y.dtype == torch.bfloat16: + upscale_y = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int16) + upscale_y = (upscale_y + 127 << 7).view(torch.bfloat16) + else: + scale_fp32 = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int32) + scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) + upscale_y = scale_fp32.to(torch.float16) + ret = torch.matmul(x * upscale_x, y * upscale_y) + return ret + + +@pytest.mark.parametrize("M, N, K, rhs_scale, normal_type, acc_num, num_warps", + [(M, N, K, rhs_scale, normal_type, acc_num, 4) + for M, N, K in itertools.product([16, 32, 64, 128], [16, 32, 64, 128], [32, 64]) + for rhs_scale in [False, True] + for normal_type in ["bf16", "fp16"] + for acc_num in [None, 1, 2]]) +def test_scaled_dot(M, N, K, rhs_scale, normal_type, num_warps, acc_num): + device = "npu" + + # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid + # overflow when scaling. + comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 + + torch.manual_seed(0) + + def make_arg(shape, ty): + if ty == "bf16" or ty == "fp16": + comp_dtype = torch.float16 if ty == "fp16" else torch.bfloat16 + ret = torch.randn(shape, dtype=comp_dtype, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2 ** comp_dtype_max_exp, 2 ** comp_dtype_max_exp - 1) + else: + ret = torch.randint(256, shape, dtype=torch.int8, device=device) + return ret + + type_a = normal_type + type_b = type_a + + x = make_arg((M, K), type_a) + y = make_arg((K, N), type_b) + + min_scale, max_scale = (0, 142) if type_a == torch.bfloat16 else (124, 131) + scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device=device) + min_scale, max_scale = (0, 142) if type_b == torch.bfloat16 else (124, 131) + scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device=device) + + if not rhs_scale: + scale_y = None + + kernel_kwargs = {"num_warps": num_warps} + z = x.new_empty((M, N), dtype=x.dtype) + pgm = dot_scale_kernel[(1,)](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, + acc_num, **kernel_kwargs) + z_ref = golden_ref(x, scale_x, y, scale_y) + if acc_num is not None: + z_ref = z_ref * (acc_num + 1) + + atol = 1e-5 + rtol = 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("B, M, N, K", [(1, 32, 64, 64)]) +def test_4d_dot(B, M, N, K): + device = "npu" + torch.manual_seed(0) + + x4d = torch.randn((B, B, M, N), dtype=torch.float16, device=device) + y4d = torch.randn((B, B, N, K), dtype=torch.float16, device=device) + + x2d = x4d.view(-1, N) # shape (B*B*M, N) + y2d = y4d.view(-1, K) # shape (B*B*N, K) + scale_x = torch.randint(-10, 10, (x2d.shape[0], N // 32), + dtype=torch.int8, device=device) + scale_y = torch.randint(-10, 10, (y2d.shape[1], N // 32), + 
dtype=torch.int8, device=device) + + z = torch.empty((x2d.shape[0], y2d.shape[0]), + dtype=x2d.dtype, device=device) + acc_num = None + dot_scale_kernel[(1,)]( + x2d, *x2d.stride(), scale_x, + y2d, *y2d.stride(), None, + z, + x2d.shape[0], y2d.shape[0], K, + "fp16", "fp16", None, + num_warps=4 + ) + z_ref = golden_ref(x2d, scale_x, y2d, None) + if acc_num is not None: + z_ref = z_ref * (acc_num + 1) + + atol = 1e-5 + rtol = 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("B, M, N, K", [(2, 16, 16, 32)]) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, + r"lhs last dimension .* must equal rhs penultimate dimension" + ) +def test_2d_dot_invaild_shape(B, M, N, K): + device = "npu" + torch.manual_seed(0) + + x4d = torch.randn((B, B, M, N), dtype=torch.float16, device=device) + y4d = torch.randn((B, B, N, K), dtype=torch.float16, device=device) + + x2d = x4d.view(-1, N) # shape (B*B*M, N) + y2d = y4d.view(-1, K) # shape (B*B*N, K) + scale_x = torch.randint(-10, 10, (x2d.shape[0], N // 32), + dtype=torch.int8, device=device) + scale_y = torch.randint(-10, 10, (y2d.shape[1], N // 32), + dtype=torch.int8, device=device) + + z = torch.empty((x2d.shape[0], y2d.shape[0]), + dtype=x2d.dtype, device=device) + acc_num = None + dot_scale_kernel[(1,)]( + x2d, *x2d.stride(), scale_x, + y2d, *y2d.stride(), None, + z, + x2d.shape[0], y2d.shape[0], K, + "fp16", "fp16", None, + num_warps=4 + ) + + +VALID_MAIN_DTYPES = { + torch.float16, # fp16 + torch.bfloat16, # bf16 +} + +ALL_DTYPES = { + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float32, # fp32 + torch.bool, +} +ILLEGAL_MAIN_DTYPES = ALL_DTYPES - VALID_MAIN_DTYPES + +ILLEGAL_SCALE_DTYPES = { + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.bfloat16, + torch.bool, +} + +from itertools import product + + +def is_legal_dtype(lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype): + return ( + lhs_dtype in VALID_MAIN_DTYPES and + rhs_dtype in VALID_MAIN_DTYPES and + lhs_scale_dtype is torch.int8 and + rhs_scale_dtype is torch.int8 + ) + + +illegal_cases = [] +for lhs, rhs, lhs_s, rhs_s in product( + VALID_MAIN_DTYPES | ILLEGAL_MAIN_DTYPES, + VALID_MAIN_DTYPES | ILLEGAL_MAIN_DTYPES, + {torch.int8} | ILLEGAL_SCALE_DTYPES, + {torch.int8} | ILLEGAL_SCALE_DTYPES, +): + + if not is_legal_dtype(lhs, rhs, lhs_s, rhs_s): + illegal_cases.append((lhs, rhs, lhs_s, rhs_s)) + +illegal_cases = sorted(set(illegal_cases), key=lambda t: tuple(str(i) for i in t)) + + +@pytest.mark.parametrize( + "lhs_dtype, rhs_dtype, lhs_scale_dtype, rhs_scale_dtype", + illegal_cases, +) +@test_common.raises_with_match(Exception, r"(?i)invalid|unsupported|dtype") +def test_invalid_dtype_should_fail(lhs_dtype, rhs_dtype, + lhs_scale_dtype, rhs_scale_dtype): + device = "npu" + M, N, K = 32, 32, 64 + num_warps = 4 + + def make_tensor(shape, dtype): + return torch.randn(shape, dtype=dtype, device=device) \ + if dtype.is_floating_point else \ + torch.randint(-10, 10, shape, dtype=dtype, device=device) + + def make_scale(shape, dtype): + return torch.randint(-10, 10, shape, dtype=dtype, device=device) + + x = make_tensor((M, K), lhs_dtype) + y = make_tensor((K, N), rhs_dtype) + lhs_scale = make_scale((M, K // 32), lhs_scale_dtype) + rhs_scale = make_scale((N, K // 32), rhs_scale_dtype) + z = torch.empty((M, N), dtype=lhs_dtype, device=device) + + dot_scale_kernel[(1,)]( + x, *x.stride(), lhs_scale, + y, *y.stride(), rhs_scale, + z, + M, N, K, + 
str(lhs_dtype).split('.')[-1], + str(rhs_dtype).split('.')[-1], + None, + num_warps=num_warps, + ) diff --git a/third_party/ascend/examples/generalization_cases/test_eq.py b/third_party/ascend/examples/generalization_cases/test_eq.py new file mode 100644 index 000000000..923105285 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_eq.py @@ -0,0 +1,113 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math +import logging + + +def torch_eq(x0, x1): + if x0.dtype != torch.uint32: + return x0 == x1 + else: + return x0.to(torch.float32) == x1.to(torch.float32) + + +@triton.jit +def triton_eq(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index, mask=x_index < N) + tmp1 = tl.load(in_ptr1 + x_index, mask=x_index < N) + tmp2 = tmp0 == tmp1 + tl.store(out_ptr0 + x_index, tmp2, mask=x_index < N) + + +@triton.jit +def triton_eq_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val == y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_eq(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + # generate test data + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + + numel = x0.numel() + ncore = 1 if numel <= 32 else 32 + xblock = math.ceil(numel / ncore) + xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) + + # torch reference result + torch_res = torch_eq(x0, x1).to(eval('torch.' + dtype)) + # triton result + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + N = triton_res.numel() + triton_eq[ncore, 1, 1](x0, x1, triton_res, N, xblock, xblock_sub) + # compare the results + torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) + triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) + cmp_dtype = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_dtype, triton_res, torch_res) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_eq_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_eq(x, y).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_eq_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_erf.py b/third_party/ascend/examples/generalization_cases/test_erf.py new file mode 100644 index 000000000..bb0c19388 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_erf.py @@ -0,0 +1,141 @@ +import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math + + +def torch_erf(x0): + res = torch.erf(x0) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.erf(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_erf_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + 
tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.erf(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # generate test data + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_erf(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_erf_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_erf(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_erf_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_exp.py b/third_party/ascend/examples/generalization_cases/test_exp.py new file mode 100644 index 000000000..e2dadc241 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_exp.py @@ -0,0 +1,142 @@ +import logging + +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_pointwise(x0): + res = torch.exp(x0) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.exp(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_exp_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: 
tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.exp(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # generate test data + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_pointwise(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_exp_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_pointwise(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_exp_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_exp2.py b/third_party/ascend/examples/generalization_cases/test_exp2.py new file mode 100644 index 000000000..5245bc609 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_exp2.py @@ -0,0 +1,140 @@ +import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math + + +def torch_exp2(x0): + res = torch.pow(2, x0, out=None) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.exp2(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_exp2_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.exp2(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # generate test data + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_exp2(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_exp2_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_exp2(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_exp2_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_expand_dims.py b/third_party/ascend/examples/generalization_cases/test_expand_dims.py new file mode 100644 index 000000000..6eaaf9baa --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_expand_dims.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): + yidx = tl.arange(0, YB) + + X = tl.load(x_ptr + yidx) + + ret = tl.expand_dims(X, 1) + + oidx = yidx[:, None] + tl.arange(0, 1)[None, :] + + tl.store(output_ptr + oidx, ret) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_expand_dims_1d(shape, dtype): + x = test_common.generate_tensor(shape,dtype).npu() + a = x.unsqueeze(1) + + output = torch.randint(1, (shape[0], 1), dtype=eval('torch.' + dtype)).npu() + + fn_npu_1d[1, 1, 1](output, x, YB=shape[0], ZB=1, debug=True) + + torch.testing.assert_close(output, a) + + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): + yoffs = tl.program_id(0) + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + + idx = yidx[:, None] * ZB + zidx[None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.expand_dims(X, 1) + + oidx = yidx[:, None, None] * ZB + tl.arange(0, 1)[None, :, None] + zidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_expand_dims_2d(shape, dtype): + x = test_common.generate_tensor(shape,dtype).npu() + a = x.unsqueeze(1) + + output = torch.randint(1, (shape[0], 1, shape[1]), dtype=eval('torch.' 
+ dtype)).npu() + + if x.numel()*x.element_size()>8192: + fn_npu_2d[shape[0],1 ,1](output, x, YB=1, ZB=shape[1]) + else: + fn_npu_2d[1, 1, 1](output, x, YB=shape[0], ZB=shape[1]) + + torch.testing.assert_close(output, a) + + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.expand_dims(X, 2) + + oidx = xidx[:, None, None, None] * YB * ZB + yidx[None, :, None, None] * ZB + tl.arange(0, 1)[None, None, :, + None] + zidx[None, None, None, :] + + tl.store(output_ptr + oidx, ret) + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +def test_expand_dims_3d(dtype, shape): + x = test_common.generate_tensor(shape,dtype).npu() + a = x.unsqueeze(2) + + output = torch.randint(1, (shape[0], shape[1], 1, shape[2]), dtype=eval('torch.' + dtype)).npu() + + fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2]) + + torch.testing.assert_close(output, a) + + +@triton.jit +def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr): + in_offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if DIMS > 1: + in_offsets = in_offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if DIMS > 2: + in_offsets = in_offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if DIMS > 3: + in_offsets = in_offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if DIMS > 4: + in_offsets = in_offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + X = tl.load(x_ptr + in_offsets) + + ret = tl.expand_dims(X, DIM).reshape(XB * YB * ZB * MB * NB) + + out_offstes = tl.arange(0, XB * YB * ZB * MB * NB) + tl.store(output_ptr + out_offstes, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('dtype', ['int8', 'float16', 'float32']) +@pytest.mark.parametrize('shape', [ + (2, 64, 16, 2), + (8, 8, 4, 2), + (8, 8, 4, 1), +]) +@pytest.mark.parametrize('dim', [-1, 0, 1, 2, 3]) +def test_npu_4d(shape, dtype, dim): + x = test_common.generate_tensor(shape, dtype).npu() + expected = x.unsqueeze(dim) + + output = torch.empty_like(expected) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + grid = (1, ) + fn_npu_multi_d[grid](output, x, *triton_shape, len(shape), dim) + + torch.testing.assert_close(output, expected) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('dtype', ['int8', 'float16', 'float32']) +@pytest.mark.parametrize('shape', [ + (2, 32, 3, 16, 2), + (8, 8, 3, 4, 2), + (8, 8, 3, 4, 1), +]) +@pytest.mark.parametrize('dim', [-1, 0, 1, 2, 3, 4]) +def test_npu_5d(shape, dtype, dim): + x = test_common.generate_tensor(shape, dtype).npu() + expected = x.unsqueeze(dim) + + output = torch.empty_like(expected) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + grid = (1, ) + fn_npu_multi_d[grid](output, x, *triton_shape, len(shape), dim) + + torch.testing.assert_close(output, expected) diff --git a/third_party/ascend/examples/generalization_cases/test_fdiv.py b/third_party/ascend/examples/generalization_cases/test_fdiv.py new file mode 100644 index 000000000..dbea60d5f --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_fdiv.py @@ -0,0 +1,142 @@ 
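+# Generalization tests for tl.fdiv on the Ascend NPU backend: a 3D kernel covers
+# 1D-3D inputs and a masked block/stride kernel covers 4D/5D inputs; results are
+# checked against plain torch division (x / y) via test_common.validate_cmp.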
+import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math + + +def torch_fdiv(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.fdiv(X, Y) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_fdiv_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = tl.fdiv(x_val, y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_fdiv(x, y) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_fdiv_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_fdiv(x, y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_fdiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_flip.py b/third_party/ascend/examples/generalization_cases/test_flip.py new file mode 100644 index 000000000..7dbfd559d --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_flip.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
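+# Generalization tests for libdevice.flip on the Ascend NPU backend: dedicated
+# 1D/2D/3D kernels flip along the last axis, and fn_npu_multi_d handles 4D/5D
+# shapes; results are compared against torch.flip(..., dims=(-1,)).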
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +import logging +from test_common import TestUtils, check_ub_mem_overflow +import triton.language.extra.ascend.libdevice as libdevice + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + idx = xidx + X = tl.load(x_ptr + idx) + ret = libdevice.flip(X, 0) + oidx = xidx + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + tl.program_id(0) * XB + yidx = tl.arange(0, YB) + tl.program_id(1) * YB + idx = xidx[:, None] * YB + yidx[None, :] + X = tl.load(x_ptr + idx) + ret = libdevice.flip(X, 1) + oidx = xidx[:, None] * YB + yidx[None, :] + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XB) + tl.program_id(0) * XB + yidx = tl.arange(0, YB) + tl.program_id(1) * YB + zidx = tl.arange(0, ZB) + tl.program_id(2) * ZB + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + X = tl.load(x_ptr + idx) + ret = libdevice.flip(X, 2) + oidx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + +typelist = ['int8','int16','int32','int64','float16','bfloat16','float32', 'bool'] +#typelist = ['int64', 'bool', 'bfloat16'] # error dtypes + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('dtype',typelist) +def test_flip(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + data_dtype = eval('torch.' 
+ dtype) + x = None + if dtype == 'bool': + x = torch.randint(low=0, high=2, size=shape, dtype=data_dtype).npu() + else: + x = torch.randint(low=0, high=128, size=shape, dtype=data_dtype).npu() + + torch_input = x if x.dtype != torch.uint32 else x.to(torch.float32) + torch_res = torch.flip(torch_input, dims=(-1,)) + triton_res = torch.empty(shape, dtype=data_dtype).npu() + + if len(shape) == 1: + fn_npu_1d[1, 1, 1](triton_res, x, shape[0], 1, 1) + elif len(shape) == 2: + fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1], 1) + elif len(shape) == 3: + if shape[0] > shape[1]: + fn_npu_3d[shape[0], 1, 1](triton_res, x, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_3d[1, shape[1], 1](triton_res, x, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + + triton_res = triton_res if triton_res.dtype != torch.uint32 else triton_res.to(torch.float32) + cmp_dtype = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_dtype, triton_res, torch_res) + + +@triton.jit +def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIMS: tl.constexpr, AXIS: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if DIMS > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if DIMS > 2: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if DIMS > 3: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if DIMS > 4: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + X = tl.load(x_ptr + offsets) + ret = libdevice.flip(X, AXIS) + tl.store(output_ptr + offsets, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (4, 2, 8, 4), + (2, 4, 2, 8, 4), + + (4, 3, 8, 1), + (3, 4, 2, 8, 1), +]) +@pytest.mark.parametrize('dtype', typelist) +def test_flip_4d_5d(shape, dtype): + data_dtype = eval('torch.' + dtype) + x = None + if dtype == 'bool': + x = torch.randint(low=0, high=2, size=shape, dtype=data_dtype).npu() + else: + x = torch.randint(low=0, high=128, size=shape, dtype=data_dtype).npu() + + torch_input = x if x.dtype != torch.uint32 else x.to(torch.float32) + torch_res = torch.flip(torch_input, dims=(-1,)) + triton_res = torch.empty(shape, dtype=data_dtype).npu() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + grid = (1, ) + fn_npu_multi_d[grid](triton_res, x, *triton_shape, len(shape), len(shape) - 1) + + triton_res = triton_res if triton_res.dtype != torch.uint32 else triton_res.to(torch.float32) + cmp_dtype = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_dtype, triton_res, torch_res) + + +if __name__ == "__main__": + for dtype in TestUtils.dtype_list: + for shape in [(37, 3), (1, 22, 39)]: + test_flip(shape, dtype) diff --git a/third_party/ascend/examples/generalization_cases/test_full_op.py b/third_party/ascend/examples/generalization_cases/test_full_op.py new file mode 100644 index 000000000..01002ef4a --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_full_op.py @@ -0,0 +1,1265 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
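+# Generalization tests for tl.full on the Ascend NPU backend: per-dtype kernels
+# materialize constant tensors for 1D through 8D shapes and compare them against
+# torch.full references via test_common.validate_cmp.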
+ +import triton +import triton.language as tl +import test_common + +from test_common import TestUtils +import torch +import torch_npu +import pytest +import math +import random + +@triton.jit +def fn_npu_int8_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.int8) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.int16) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_uint32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.uint32) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.int32) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int64_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.int64) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.float16) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.float32) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_bf16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=100, dtype=tl.bfloat16) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_bool_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr): + xidx = tl.arange(0, X) + yidx = tl.arange(0, Y) + zidx = tl.arange(0, Z) + ret = tl.full((X, Y, Z), value=0, dtype=tl.int1) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int8_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, 
dtype=tl.int8) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.int16) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_uint32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.uint32) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.int32) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int64_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.int64) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.float16) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.float32) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_bf16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=100, dtype=tl.bfloat16) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_bool_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr): + yoffs = tl.program_id(0) * Y + yidx = tl.arange(0, Y) + yoffs + zidx = tl.arange(0, Z) + ret = tl.full((Y, Z), value=0, dtype=tl.int1) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int8_1d(output_ptr, Z: tl.constexpr): + zidx = tl.arange(0, Z) + ret = tl.full((Z,), value=100, dtype=tl.int8) + oidx = zidx + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int16_1d(output_ptr, Z: tl.constexpr): + zidx = tl.arange(0, Z) + ret = tl.full((Z,), value=100, dtype=tl.int16) + oidx = zidx + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_uint32_1d(output_ptr, Z: tl.constexpr): + zidx = tl.arange(0, Z) + ret = tl.full((Z,), value=100, dtype=tl.uint32) + oidx = zidx + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int32_1d(output_ptr, Z: tl.constexpr): + zidx = tl.arange(0, Z) + ret = tl.full((Z,), value=100, dtype=tl.int32) + oidx = zidx + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int64_1d(output_ptr, Z: tl.constexpr): + zidx = tl.arange(0, Z) + ret = tl.full((Z,), value=100, dtype=tl.int64) + oidx = zidx + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp16_1d(output_ptr, Z: tl.constexpr): + zidx = 
tl.arange(0, Z)
+    ret = tl.full((Z,), value=100, dtype=tl.float16)
+    oidx = zidx
+    tl.store(output_ptr + oidx, ret)
+
+
+@triton.jit
+def fn_npu_fp32_1d(output_ptr, Z: tl.constexpr):
+    zidx = tl.arange(0, Z)
+    ret = tl.full((Z,), value=100, dtype=tl.float32)
+    oidx = zidx
+    tl.store(output_ptr + oidx, ret)
+
+
+@triton.jit
+def fn_npu_bf16_1d(output_ptr, Z: tl.constexpr):
+    zidx = tl.arange(0, Z)
+    ret = tl.full((Z,), value=100, dtype=tl.bfloat16)
+    oidx = zidx
+    tl.store(output_ptr + oidx, ret)
+
+
+@triton.jit
+def fn_npu_bool_1d(output_ptr, Z: tl.constexpr):
+    zidx = tl.arange(0, Z)
+    ret = tl.full((Z,), value=0, dtype=tl.int1)
+    oidx = zidx
+    tl.store(output_ptr + oidx, ret)
+
+
+test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']
+test_shape1d = TestUtils.test_shape1d
+test_shape2d = TestUtils.test_shape2d
+test_shape3d = TestUtils.test_shape3d
+
+# Map each dtype name to its (test_func, test_sigtype) pair
+dtype_mapping3d = {
+    'int8': (fn_npu_int8_3d, torch.int8),
+    'int16': (fn_npu_int16_3d, torch.int16),
+    'int32': (fn_npu_int32_3d, torch.int32),
+    'uint32': (fn_npu_uint32_3d, torch.uint32),
+    'int64': (fn_npu_int64_3d, torch.int64),
+    'float16': (fn_npu_fp16_3d, torch.float16),
+    'float32': (fn_npu_fp32_3d, torch.float32),
+    'bfloat16': (fn_npu_bf16_3d, torch.bfloat16),
+    'bool': (fn_npu_bool_3d, torch.bool),
+}
+dtype_mapping2d = {
+    'int8': (fn_npu_int8_2d, torch.int8),
+    'int16': (fn_npu_int16_2d, torch.int16),
+    'int32': (fn_npu_int32_2d, torch.int32),
+    'uint32': (fn_npu_uint32_2d, torch.uint32),
+    'int64': (fn_npu_int64_2d, torch.int64),
+    'float16': (fn_npu_fp16_2d, torch.float16),
+    'float32': (fn_npu_fp32_2d, torch.float32),
+    'bfloat16': (fn_npu_bf16_2d, torch.bfloat16),
+    'bool': (fn_npu_bool_2d, torch.bool),
+}
+dtype_mapping1d = {
+    'int8': (fn_npu_int8_1d, torch.int8),
+    'int16': (fn_npu_int16_1d, torch.int16),
+    'int32': (fn_npu_int32_1d, torch.int32),
+    'uint32': (fn_npu_uint32_1d, torch.uint32),
+    'int64': (fn_npu_int64_1d, torch.int64),
+    'float16': (fn_npu_fp16_1d, torch.float16),
+    'float32': (fn_npu_fp32_1d, torch.float32),
+    'bfloat16': (fn_npu_bf16_1d, torch.bfloat16),
+    'bool': (fn_npu_bool_1d, torch.bool),
+}
+
+# Generate the test cases
+testlist = [
+    (func, sigtype, dtype, shape)
+    for sigtype in test_dtype
+    for shape in test_shape1d
+    for func, dtype in [dtype_mapping1d[sigtype]]  # unpack the mapped pair directly
+]
+
+testlist += [
+    (func, sigtype, dtype, shape)
+    for sigtype in test_dtype
+    for shape in test_shape2d
+    for func, dtype in [dtype_mapping2d[sigtype]]  # unpack the mapped pair directly
+]
+
+testlist += [
+    (func, sigtype, dtype, shape)
+    for sigtype in test_dtype
+    for shape in test_shape3d
+    for func, dtype in [dtype_mapping3d[sigtype]]  # unpack the mapped pair directly
+]
+
+@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist)
+def test_npu(testfunc, sigtype, dtype, shape):
+    x = 0
+    output = 0
+    if len(shape) == 3:
+        if dtype == torch.bool:
+            x = torch.full((shape[0], shape[1], shape[2]), 0, dtype=dtype).npu()
+        else:
+            x = torch.full((shape[0], shape[1], shape[2]), 100, dtype=dtype).npu()
+        output = torch.randint(1, (shape[0], shape[1], shape[2]), dtype=dtype).npu()
+        testfunc[(1, 1, 1)](output, shape[0], shape[1], shape[2], debug=True)
+    if len(shape) == 2:
+        if dtype == torch.bool:
+            x = torch.full((shape[0], shape[1]), 0, dtype=dtype).npu()
+        else:
+            x = torch.full((shape[0], shape[1]), 100, dtype=dtype).npu()
+        output = torch.randint(1, (shape[0], shape[1]), dtype=dtype).npu()
+        shape0 = shape[0]
+        shape1 = shape[1]
+        if x.numel() * x.element_size() >= 8192:
+            grid =
(shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + testfunc[grid](output, shape0, shape1, debug=True) + if len(shape) == 1: + if dtype == torch.bool: + x = torch.full((shape[0], ), 0, dtype=dtype).npu() + else: + x = torch.full((shape[0], ), 100, dtype=dtype).npu() + output = torch.randint(1, (shape[0],), dtype=dtype).npu() + testfunc[1, 1, 1](output, shape[0], debug=True) + test_common.validate_cmp(sigtype, output, x) + + +@triton.jit +def fn_npu_multi_d(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr): + dtype = output_ptr.type.element_ty + + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + if (YB * ZB * MB * NB) == 1: + ret = tl.full((XB, ), value=100, dtype=dtype) + elif (ZB * MB * NB) == 1: + ret = tl.full((XB, YB), value=100, dtype=dtype) + elif (MB * NB) == 1: + ret = tl.full((XB, YB, ZB), value=100, dtype=dtype) + elif NB == 1: + ret = tl.full((XB, YB, ZB, MB), value=100, dtype=dtype) + else: + ret = tl.full((XB, YB, ZB, MB, NB), value=100, dtype=dtype) + + tl.store(output_ptr + offsets, ret) + + +testlist_multi_d = [ + (fn_npu_multi_d, 'float32', torch.float32, (4, 2, 16, 16)), + (fn_npu_multi_d, 'float32', torch.float32, (2, 4, 2, 16, 16)), + + (fn_npu_multi_d, 'float32', torch.float16, (4, 2, 16, 16)), + (fn_npu_multi_d, 'float32', torch.float16, (2, 4, 2, 16, 16)), + + (fn_npu_multi_d, 'float32', torch.int8, (4, 2, 16, 16)), + (fn_npu_multi_d, 'float32', torch.int8, (2, 4, 2, 16, 16)), +] + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist_multi_d) +def test_npu_4d_5d(testfunc, sigtype, dtype, shape): + x = torch.full(shape, 100, dtype=dtype).npu() + + print(f"shape = {x.shape}") + print(x.dtype) + print(torch.flatten(x)[0:16]) + + output = torch.randint(1, shape, dtype=dtype).npu() + + print(f"output.dtype={output.dtype}") + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + testfunc[(1,)](output, *triton_shape) + print(torch.flatten(output)[0:16]) + + test_common.validate_cmp(sigtype, output, x) + + +@triton.jit +def fn_npu_bf16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.bfloat16) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int8_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = 
tl.full((A, B, C, D, E, F), value=10, dtype=tl.int8) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int16) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int32_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int32) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int64_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.int64) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp16_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.float16) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_fp32_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + 
bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=10, dtype=tl.float32) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_bool_6d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + + ret = tl.full((A, B, C, D, E, F), value=0, dtype=tl.int1) + + oidx = (aidx[:, None, None, None, None, None] * + B * C * D * E * F + + bidx[None, :, None, None, None, None] * + C * D * E * F + + cidx[None, None, :, None, None, None] * + D * E * F + + didx[None, None, None, :, None, None] * + E * F + + eidx[None, None, None, None, :, None] * + F + + fidx[None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] +test_shape6d = TestUtils.test_shape6d +dtype_mapping6d = { + 'int8': (fn_npu_int8_6d, torch.int8), + 'int16': (fn_npu_int16_6d, torch.int16), + 'int32': (fn_npu_int32_6d, torch.int32), + 'int64': (fn_npu_int64_6d, torch.int64), + 'float16': (fn_npu_fp16_6d, torch.float16), + 'float32': (fn_npu_fp32_6d, torch.float32), + 'bfloat16': (fn_npu_bf16_6d, torch.bfloat16), + 'bool': (fn_npu_bool_6d, torch.bool), +} + +testlist6d = [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape6d + for func, dtype in [dtype_mapping6d[sigtype]] +] + + +@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist6d) +def test_npu_6d(testfunc, sigtype, dtype, shape): + x = 0 + output = 0 + if len(shape) == 6: + if dtype == torch.bool: + x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]), 0, dtype=dtype).npu() + else: + x = torch.full(shape, 10, dtype=dtype).npu() + output = torch.randint(1, shape, dtype=dtype).npu() + testfunc[1, 1, 1](output, *shape, debug=True) + test_common.validate_cmp(sigtype, output, x) + + +@triton.jit +def fn_npu_bf16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.bfloat16) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_int8_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + aidx = tl.arange(0, A) + 
bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int8) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_int16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int16) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int32_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int32) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int64_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.int64) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_fp16_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + 
aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.float16) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_fp32_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=10, dtype=tl.float32) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_bool_7d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + + ret = tl.full((A, B, C, D, E, F, G), value=0, dtype=tl.int1) + + oidx = (aidx[:, None, None, None, None, None, None] * + B * C * D * E * F * G + + bidx[None, :, None, None, None, None, None] * + C * D * E * F * G + + cidx[None, None, :, None, None, None, None] * + D * E * F * G + + didx[None, None, None, :, None, None, None] * + E * F * G + + eidx[None, None, None, None, :, None, None] * + F * G + + fidx[None, None, None, None, None, :, None] * + G + + gidx[None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] +test_shape7d = TestUtils.test_shape7d +dtype_mapping7d = { + 'int8': (fn_npu_int8_7d, torch.int8), + 'int16': (fn_npu_int16_7d, torch.int16), + 'int32': (fn_npu_int32_7d, torch.int32), + 'int64': (fn_npu_int64_7d, torch.int64), + 'float16': (fn_npu_fp16_7d, torch.float16), + 'float32': (fn_npu_fp32_7d, torch.float32), + 'bfloat16': (fn_npu_bf16_7d, torch.bfloat16), + 'bool': (fn_npu_bool_7d, torch.bool), +} + +testlist7d = [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape7d + for func, dtype in [dtype_mapping7d[sigtype]] +] + +@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist7d) +def test_npu_7d(testfunc, sigtype, dtype, shape): + x = 0 + output = 0 + if len(shape) == 7: + if dtype == torch.bool: + x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6]), 0, dtype=dtype).npu() + else: + x = torch.full(shape, 10, dtype=dtype).npu() + output = 
torch.randint(1, shape, dtype=dtype).npu() + testfunc[1, 1, 1](output, *shape, debug=True) + test_common.validate_cmp(sigtype, output, x) + + +@triton.jit +def fn_npu_bf16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.bfloat16) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_int8_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int8) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_int16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int16) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int32_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + 
bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int32) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_int64_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.int64) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_fp16_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.float16) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_fp32_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=10, dtype=tl.float32) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, 
None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_bool_8d(output_ptr, A: tl.constexpr, B: tl.constexpr, C: tl.constexpr, + D: tl.constexpr, E: tl.constexpr, F: tl.constexpr, + G: tl.constexpr, H: tl.constexpr): + + aidx = tl.arange(0, A) + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + eidx = tl.arange(0, E) + fidx = tl.arange(0, F) + gidx = tl.arange(0, G) + hidx = tl.arange(0, H) + + ret = tl.full((A, B, C, D, E, F, G, H), value=0, dtype=tl.int1) + + oidx = (aidx[:, None, None, None, None, None, None, None] * + B * C * D * E * F * G * H + + bidx[None, :, None, None, None, None, None, None] * + C * D * E * F * G * H + + cidx[None, None, :, None, None, None, None, None] * + D * E * F * G * H + + didx[None, None, None, :, None, None, None, None] * + E * F * G * H + + eidx[None, None, None, None, :, None, None, None] * + F * G * H + + fidx[None, None, None, None, None, :, None, None] * + G * H + + gidx[None, None, None, None, None, None, :, None] * + H + + hidx[None, None, None, None, None, None, None, :]) + + tl.store(output_ptr + oidx, ret) + +test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] +test_shape8d = TestUtils.test_shape8d +dtype_mapping8d = { + 'int8': (fn_npu_int8_8d, torch.int8), + 'int16': (fn_npu_int16_8d, torch.int16), + 'int32': (fn_npu_int32_8d, torch.int32), + 'int64': (fn_npu_int64_8d, torch.int64), + 'float16': (fn_npu_fp16_8d, torch.float16), + 'float32': (fn_npu_fp32_8d, torch.float32), + 'bfloat16': (fn_npu_bf16_8d, torch.bfloat16), + 'bool': (fn_npu_bool_8d, torch.bool), +} + +testlist8d = [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape8d + for func, dtype in [dtype_mapping8d[sigtype]] +] + + +@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist8d) +def test_npu_8d(testfunc, sigtype, dtype, shape): + x = 0 + output = 0 + if len(shape) == 8: + if dtype == torch.bool: + x = torch.full((shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6], shape[7]), 0, dtype=dtype).npu() + else: + x = torch.full(shape, 10, dtype=dtype).npu() + output = torch.randint(1, shape, dtype=dtype).npu() + testfunc[1, 1, 1](output, *shape, debug=True) + test_common.validate_cmp(sigtype, output, x) diff --git a/third_party/ascend/examples/generalization_cases/test_ge_op.py b/third_party/ascend/examples/generalization_cases/test_ge_op.py new file mode 100644 index 000000000..ded77051d --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_ge_op.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
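+# Generalization tests for the >= comparison on the Ascend NPU backend: 1D/2D/3D
+# kernels plus a masked 4D/5D kernel compute x >= y and are checked against
+# torch.ge via test_common.validate_cmp.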
+import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_ge_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 >= x1 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_ge_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): + moffs = tl.program_id(0) * M + mblk_idx = tl.arange(0, M) + moffs + nblk_idx = tl.arange(0, N) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 >= x1 + odx = mblk_idx[:, None] * N + nblk_idx[None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_ge_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): + lblk_idx = tl.arange(0, L) + idx = lblk_idx[:] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 >= x1 + odx = lblk_idx[:] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_ge_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val >= y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype', typelist) +def test_ge(sigtype, shape): + dtype = dtype_mapping[sigtype] + x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=shape, 
dtype=sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = torch.where(torch.ge(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 3: + triton_ge_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) + if len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_ge_2d[grid](x0, x1, output, shape0, shape1) + if len(shape) == 1: + triton_ge_1d[1, 1, 1](x0, x1, output, shape[0]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_ge_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.where(torch.ge(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_ge_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_add.py b/third_party/ascend/examples/generalization_cases/test_general_add.py new file mode 100644 index 000000000..40417f97c --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_add.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+# Only floating point clamp is supported +import pytest + +import triton +import triton.language as tl +import torch +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_add(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X + Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_add_broadcast(in_ptr0, in_ptr1, out_ptr0, X_SHAPE_0: tl.constexpr, + X_SHAPE_1: tl.constexpr, X_SHAPE_2: tl.constexpr, X_SHAPE_3: tl.constexpr, X_SHAPE_4: tl.constexpr, + Y_SHAPE_0: tl.constexpr, Y_SHAPE_1: tl.constexpr, Y_SHAPE_2: tl.constexpr, Y_SHAPE_3: tl.constexpr, + Y_SHAPE_4: tl.constexpr): + x_idx0 = tl.arange(0, X_SHAPE_0) + x_idx1 = tl.arange(0, X_SHAPE_1) + x_idx2 = tl.arange(0, X_SHAPE_2) + x_idx3 = tl.arange(0, X_SHAPE_3) + x_idx4 = tl.arange(0, X_SHAPE_4) + + y_idx0 = tl.arange(0, Y_SHAPE_0) + y_idx1 = tl.arange(0, Y_SHAPE_1) + y_idx2 = tl.arange(0, Y_SHAPE_2) + y_idx3 = tl.arange(0, Y_SHAPE_3) + y_idx4 = tl.arange(0, Y_SHAPE_4) + + xidx = x_idx0[:, None, None, None, None] * X_SHAPE_1 * X_SHAPE_2 * X_SHAPE_3 * X_SHAPE_4 + \ + x_idx1[None, :, None, None, None] * X_SHAPE_2 * X_SHAPE_3 * X_SHAPE_4 + \ + x_idx2[None, None, :, None, None] * X_SHAPE_3 * X_SHAPE_4 + \ + x_idx3[None, None, None, :, None] * X_SHAPE_4 + x_idx4[None, None, None, None, :] + + yidx = y_idx0[:, None, None, None, None] * Y_SHAPE_1 * Y_SHAPE_2 * Y_SHAPE_3 * Y_SHAPE_4 + \ + y_idx1[None, :, None, None, None] * Y_SHAPE_2 * Y_SHAPE_3 * Y_SHAPE_4 + \ + y_idx2[None, None, :, None, None] * Y_SHAPE_3 * Y_SHAPE_4 + \ + y_idx3[None, None, None, :, None] * Y_SHAPE_4 + y_idx4[None, None, None, None, :] + + X = tl.load(in_ptr0 + xidx) + Y = tl.load(in_ptr1 + yidx) + ret = X + Y + + tl.store(out_ptr0 + xidx, ret) + + +@triton.jit +def triton_add_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, 
BLOCK_4)[None, None, None, None, :] < SHAPE_4)
+
+    x_val = tl.load(x_ptr + offsets, masks)
+    y_val = tl.load(y_ptr + offsets, masks)
+    ret = x_val + y_val
+    tl.store(output_ptr + offsets, ret, mask=masks)
+
+
+@pytest.mark.parametrize('shape', TestUtils.full_shape)
+@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16'])
+def test_add(shape, dtype):
+    logging.log(logging.DEBUG, f"shape = {shape}")
+    x = test_common.generate_tensor(shape, dtype).npu()
+    y = test_common.generate_tensor(shape, dtype).npu()
+    z = test_common.generate_tensor(shape, dtype).npu()
+
+    ans = x + y
+    output = torch.zeros_like(ans)
+
+    if len(shape) == 1:
+        triton_add[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0])
+    elif len(shape) == 2:
+        if shape[0] > shape[1]:
+            triton_add[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1])
+        else:
+            triton_add[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1])
+    elif len(shape) == 3:
+        if max(shape[0], shape[1], shape[2]) == shape[0]:
+            triton_add[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2])
+        elif max(shape[0], shape[1], shape[2]) == shape[1]:
+            triton_add[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2])
+        else:
+            triton_add[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2])
+    else:
+        triton_add[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1)
+
+    test_common.validate_cmp(dtype, ans, output)
+
+
+@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d)
+@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16'])
+def test_add_4d_5d(shape, dtype):
+    logging.log(logging.DEBUG, f"shape = {shape}")
+    x = test_common.generate_tensor(shape, dtype).npu()
+    y = test_common.generate_tensor(shape, dtype).npu()
+
+    output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu()
+
+    logging.log(logging.DEBUG, f"output.dtype={output.dtype}")
+
+    ans = x + y
+
+    blocks = list(x.size())
+    strides = list(x.stride())
+    while len(blocks) < 5:
+        blocks.append(1)
+        strides.append(1)
+
+    grid = (1,)
+    triton_add_4d_5d[grid](output, x, y, *blocks, *blocks, *strides)
+
+    test_common.validate_cmp(dtype, ans, output)
+
+
+def promote_dtype(x_dtype, y_dtype):
+    """
+    If y has lower precision than x, promote y's dtype to match x.
+    """
+    # If the two dtypes already match, return as-is
+    if x_dtype == y_dtype:
+        return y_dtype
+
+    # Priority list of dtypes, from lowest to highest
+    priority = [
+        torch.bool, torch.int8, torch.int16, torch.int32, torch.int64,
+        torch.float16, torch.bfloat16, torch.float32
+    ]
+
+    # Find the position of each dtype in the priority list
+    x_priority = priority.index(x_dtype)
+    y_priority = priority.index(y_dtype)
+
+    # If y ranks lower than x, promote it to x's dtype
+    if y_priority < x_priority:
+        return x_dtype
+    else:
+        return y_dtype
+
+
+@pytest.mark.parametrize('param_list',
+                         [
+                             [(5, 1, 1, 1, 1), (5, 1, 1, 2, 1)],
+                             [(2, 1), (2, 4)],
+                             [(2, 1, 1), (2, 4, 2)],
+                             [(2, 1, 1, 1), (2, 4, 2, 2)],
+                             [(2, 1, 1, 1, 1), (2, 4, 2, 2, 2)],
+                             [(1, ), (4, )],
+                             [(1, 2, 1), (1, 2, 3)],
+                             [(1, 1, 1, 1), (7, 1, 1, 1)]
+                         ]
+                         )
+@pytest.mark.parametrize('x_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'])
+@pytest.mark.parametrize('y_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'])
+def test_add_broadcast(param_list, x_dtype_str, y_dtype_str):
+    x_shape, y_shape = param_list
+    x_dtype = eval('torch.' + x_dtype_str)
+    y_dtype = eval('torch.'
+ y_dtype_str) + + x = test_common.generate_tensor(x_shape, x_dtype_str).npu() + y = test_common.generate_tensor(y_shape, y_dtype_str).npu() + if y.numel() > x.numel(): + tmp = y + y = x + x = tmp + ans = x + y + while x.dim() < 5: + x = x.unsqueeze(-1) + while y.dim() < 5: + y = y.unsqueeze(-1) + bf2fpFlag = False + out_dtype = promote_dtype(x_dtype, y_dtype) + if (x_dtype == torch.bfloat16 and y_dtype == torch.float16) or \ + (x_dtype == torch.float16 and y_dtype == torch.bfloat16): + out_dtype = torch.float32 + bf2fpFlag = True + out_dtype = str(out_dtype).split('.')[-1] + out = test_common.generate_tensor(x.shape, out_dtype).npu() + + triton_add_broadcast[1, 1, 1](x, y, out, *x.shape, *y.shape) + while out.dim() > ans.dim(): + out = out.squeeze(-1) + + if bf2fpFlag: + torch.testing.assert_close(out, ans, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(out, ans) + + + +@triton.jit +def add_5d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr, NB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1 * NB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1 * NB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1 * NB1) + offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] * NB1 + offsets1 = offsets1[:, :, :, :, None] + tl.arange(0, NB1)[None, None, None, None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tmp2 = tl.load(out_ptr + offsets1) + out = tmp2 + tmp1 + tmp0 + tl.store(out_ptr + offsets1, out) + + +@triton.jit +def add_4d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr, MB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB) + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1 * MB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1 * MB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] * (MB1) + offsets1 = offsets1[:, :, :, None] + tl.arange(0, MB1)[None, None, None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tmp2 = tl.load(out_ptr + offsets1) + out = tmp2 + tmp1 + tmp0 + tl.store(out_ptr + offsets1, out) + + +@triton.jit +def add_3d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr, ZB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB) + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] + + offsets1 = tl.arange(0, XB1) * (YB1 * ZB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] * (ZB1) + offsets1 = offsets1[:, :, None] + tl.arange(0, ZB1)[None, None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = 
tl.load(y_ptr + offsets1) + tmp2 = tl.load(out_ptr + offsets1) + out = tmp2 + tmp1 + tmp0 + tl.store(out_ptr + offsets1, out) + + +@triton.jit +def add_2d(x_ptr, y_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, + XB1: tl.constexpr, YB1: tl.constexpr): + offsets = tl.arange(0, XB) * (YB) + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] + + offsets1 = tl.arange(0, XB1) * (YB1) + offsets1 = offsets1[:, None] + tl.arange(0, YB1)[None, :] + + tmp0 = tl.load(x_ptr + offsets) + tmp1 = tl.load(y_ptr + offsets1) + tmp2 = tl.load(out_ptr + offsets1) + out = tmp2 + tmp1 + tmp0 + tl.store(out_ptr + offsets1, out) + + +@pytest.mark.parametrize('param_list', + [ + [(5, 1, 1, 1, 1), (5, 1, 1, 2, 1)], + ] + ) +@pytest.mark.parametrize('x_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) +@pytest.mark.parametrize('y_dtype_str', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) +def test_add_2d_to_5d(x_dtype_str, y_dtype_str, param_list): + x0_shape, y_shape = param_list + ndim = max(len(x0_shape), len(y_shape)) + # 获取原始类型 + x_dtype = eval('torch.' + x_dtype_str) + y_dtype = eval('torch.' + y_dtype_str) + + x0 = test_common.generate_tensor(x0_shape, x_dtype_str).npu() + y = test_common.generate_tensor(y_shape, y_dtype_str).npu() + + out_dtype = promote_dtype(x_dtype, y_dtype) + if out_dtype == torch.bfloat16: + out_dtype = torch.float32 + out = torch.full(y_shape, 0, dtype=out_dtype).npu() + + x0_temp = x0.clone() + y_temp = y.clone() + out_temp = out.clone() + + triton_shape = [*x0_shape] + while len(triton_shape) < ndim: + triton_shape.append(1) + + triton_shape1 = [*y_shape] + while len(triton_shape1) < ndim: + triton_shape1.append(1) + + # 按维度分支 + if ndim == 2: + XB, YB = triton_shape + XB1, YB1 = triton_shape1 + + add_2d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, + XB1=XB1, YB1=YB1, + ) + + elif ndim == 3: + XB, YB, ZB = triton_shape + XB1, YB1, ZB1 = triton_shape1 + + add_3d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, + XB1=XB1, YB1=YB1, ZB1=ZB1, + ) + + elif ndim == 4: + XB, YB, ZB, MB = triton_shape + XB1, YB1, ZB1, MB1 = triton_shape1 + + add_4d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, MB=MB, + XB1=XB1, YB1=YB1, ZB1=ZB1, MB1=MB1, + ) + + elif ndim == 5: + XB, YB, ZB, MB, NB = triton_shape + XB1, YB1, ZB1, MB1, NB1 = triton_shape1 + + add_5d[(1, )]( + x_ptr=x0, + y_ptr=y, + out_ptr=out, + XB=XB, YB=YB, ZB=ZB, MB=MB, NB=NB, + XB1=XB1, YB1=YB1, ZB1=ZB1, MB1=MB1, NB1=NB1, + ) + + else: + raise ValueError(f"Unsupported tensor dim: {ndim}") + expected = out_temp + y_temp + x0_temp + torch.testing.assert_close(out, expected) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_general_arange.py b/third_party/ascend/examples/generalization_cases/test_general_arange.py new file mode 100644 index 000000000..e76267d49 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_arange.py @@ -0,0 +1,44 @@ +import math +import pytest +import torch +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils + + +def torch_pointwise(length): + res = (torch.arange(0, length) / 2.7) * torch.arange(0, length) + return res + + +@triton.jit +def triton_arange(out_ptr0, length: tl.constexpr, numel: tl.constexpr): + offs = tl.program_id(0) * length + idx = offs + tl.arange(0, length) + a = idx / 2.7 + b = idx * a + mask = idx < numel + tl.store(out_ptr0 + idx, b, 
mask) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['int32', 'int16', 'int8', 'int64']) +def test_case(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + + numel = x0.numel() + ncore = 32 if dtype == 'int8' and numel > 127 else 1 + if dtype in ('float16', 'bfloat16', 'float32', 'bool'): + # tl.arange doesn't support float and bool + xblock = numel / ncore + else: + xblock = math.ceil(numel / ncore) + + y_ref = torch_pointwise(numel) + y_cal = torch.zeros(shape, dtype=torch.float32).npu() + + triton_arange[ncore, 1, 1](y_cal, xblock, numel) + + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_general_cat.py b/third_party/ascend/examples/generalization_cases/test_general_cat.py new file mode 100644 index 000000000..8d460a7d4 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_cat.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, XB: tl.constexpr): + + idx = tl.arange(0, XB) + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.cat(X, Y, can_reorder=True) + + oidx = tl.arange(0, XB * 2) + + tl.store(output_ptr + oidx, ret) + +# The CAT operator in the Triton community also does not support boolean types. +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) #triton only support 1D cat +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int32', 'int16', 'int8', 'int64']) +def test_cat(shape, dtype): + m = shape[0] + x = torch.full((m, ), 100, dtype=eval("torch." + dtype)).npu() + y = torch.full((m, ), 30, dtype=eval("torch." + dtype)).npu() + + output = torch.randint(1, (m * 2, ), dtype=eval("torch." + dtype)).npu() + + ans = torch.cat((x, y), dim=0) + + fn_npu_[1, 1, 1](output, x, y, m) + + test_common.validate_cmp(dtype, ans, output) + diff --git a/third_party/ascend/examples/generalization_cases/test_general_clamp.py b/third_party/ascend/examples/generalization_cases/test_general_clamp.py new file mode 100644 index 000000000..a23aa9bec --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_clamp.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
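Note (not part of the patch): test_cat above fills x and y with the constants 100 and 30; with can_reorder=True, tl.cat is allowed to reorder elements, so only the multiset of values is guaranteed and constant inputs keep the comparison order-insensitive. A host-side sketch of an order-insensitive check for non-constant data, for illustration only:

import torch

def assert_same_elements(result: torch.Tensor, expected: torch.Tensor) -> None:
    # Compare the tensors as multisets: sort both flattened views first.
    assert torch.equal(torch.sort(result.flatten()).values,
                       torch.sort(expected.flatten()).values)

x = torch.full((8,), 100)
y = torch.full((8,), 30)
assert_same_elements(torch.cat((y, x)), torch.cat((x, y)))  # passes regardless of element order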
+# Only floating point clamp is supported +import pytest + +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils +import logging + + +def torch_clamp(x0, min_, max_): + res = torch.clamp(x0, min_, max_) + return res + + +@triton.jit +def tt_clamp_1d(in_ptr, out_ptr, min_ptr, max_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + idx = tl.arange(0, XB) + + x = tl.load(in_ptr + idx) + min_ = tl.load(min_ptr + idx) + max_ = tl.load(max_ptr + idx) + ret = tl.clamp(x, min_, max_) + + tl.store(out_ptr + idx, ret) + + +@triton.jit +def tt_clamp_2d(in_ptr, out_ptr, min_ptr, max_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + idx = xidx[:, None] * ynumel + yidx[None, :] + + x = tl.load(in_ptr + idx) + min_ = tl.load(min_ptr + idx) + max_ = tl.load(max_ptr + idx) + ret = tl.clamp(x, min_, max_) + + tl.store(out_ptr + idx, ret) + + +@triton.jit +def tt_clamp_3d(in_ptr, out_ptr, min_ptr, max_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + min_ = tl.load(min_ptr + idx) + max_ = tl.load(max_ptr + idx) + ret = tl.clamp(x, min_, max_) + + tl.store(out_ptr + idx, ret) + + +@triton.jit +def triton_clamp_4d_5d( + x_ptr, output_ptr, min_ptr, max_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + min_ = tl.load(min_ptr + offsets) + max_ = tl.load(max_ptr + offsets) + ret = tl.clamp(x_val, min_, max_) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) 
+@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_clamp(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + torch.manual_seed(0) + x = test_common.generate_tensor(shape, dtype).npu() + a = test_common.generate_tensor(shape, dtype) + b = test_common.generate_tensor(shape, dtype) + min_ = torch.min(a, b).npu() + max_ = torch.max(a, b).npu() + + grid = (1, 1, 1) + + y_cal = torch.empty(shape, dtype=eval('torch.' + dtype), device="npu") + + y_ref = torch_clamp(x, min_, max_) + if len(shape) == 1: + tt_clamp_1d[grid](x, y_cal, min_, max_, x.numel(), 1, 1, x.numel(), 1, 1) + elif len(shape) == 2: + xnumel, ynumel, znumel = shape + (1,) + XB, YB, ZB = xnumel, ynumel, znumel + if x.numel() * x.element_size() > 8192: + grid = (1, ynumel, 1) + YB = 1 + tt_clamp_2d[grid](x, y_cal, min_, max_, xnumel, ynumel, znumel, XB, YB, ZB) + + elif len(shape) == 3: + xnumel, ynumel, znumel = shape + XB, YB, ZB = xnumel, ynumel, znumel + tt_clamp_3d[grid](x, y_cal, min_, max_, xnumel, ynumel, znumel, XB, YB, ZB) + + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_clamp_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + torch.manual_seed(0) + x = test_common.generate_tensor(shape, dtype).npu() + a = test_common.generate_tensor(shape, dtype) + b = test_common.generate_tensor(shape, dtype) + min_ = torch.min(a, b).npu() + max_ = torch.max(a, b).npu() + + output = torch.empty(shape, dtype=eval('torch.' + dtype), device="npu") + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_clamp(x, min_, max_) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_clamp_4d_5d[grid](x, output, min_, max_, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_div.py b/third_party/ascend/examples/generalization_cases/test_general_div.py new file mode 100644 index 000000000..4f21e300e --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_div.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
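Note (not part of the patch): the *_4d_5d kernels in this patch all build offsets and masks one axis at a time, expanding a new dimension only when the remaining block holds more than one element; for a contiguous tensor the fully expanded form is just the stride dot product. A host-side analogue of the 5-D case, for illustration only:

import torch

def full_offsets(shape, strides):
    # Mirrors the fully expanded 5-D offsets used by the *_4d_5d kernels.
    d0, d1, d2, d3, d4 = (torch.arange(s) for s in shape)
    return (d0[:, None, None, None, None] * strides[0]
            + d1[None, :, None, None, None] * strides[1]
            + d2[None, None, :, None, None] * strides[2]
            + d3[None, None, None, :, None] * strides[3]
            + d4[None, None, None, None, :] * strides[4])

x = torch.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
offs = full_offsets(x.shape, x.stride())
assert torch.equal(x.flatten()[offs], x)   # each element sits at sum(index * stride)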
+# Only floating point clamp is supported +import pytest + +import triton +import triton.language as tl +import torch +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_div(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X / Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_div_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val / y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub +@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_div(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + y[y == 0] = 1 + + ans = x / y + output = torch.zeros_like(ans) + if len(shape) == 1: + triton_div[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + triton_div[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + triton_div[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + triton_div[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + triton_div[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + triton_div[1, 1, shape[2]](output, x, y, z, shape[0], 
shape[1], 1, shape[0], shape[1], shape[2]) + else: + triton_div[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + # change dtype beacuse of triton processing, triton div op will change from int to float + if dtype in ['int8', 'int16', 'int32', 'int64']: + dtype = 'float32' + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_div_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + y[y == 0] = 1 + + new_shape = shape + if dtype == 'int8' or dtype == 'int16' or dtype == 'int32' or dtype == 'int64': + output = torch.randint(1, new_shape, dtype=eval('torch.float32')).npu() + dtype = 'float32' + else: + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + + ans = x / y + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_div_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_general_floor.py b/third_party/ascend/examples/generalization_cases/test_general_floor.py new file mode 100644 index 000000000..81829303a --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_floor.py @@ -0,0 +1,134 @@ +import triton +import triton.language as tl +import torch +import logging +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X + tl.floor(Y) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_floor_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + 
offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val + tl.floor(y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_floor(dtype, shape): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x + torch.floor(y) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_floor_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x + torch.floor(y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_floor_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_floordiv.py b/third_party/ascend/examples/generalization_cases/test_general_floordiv.py new file mode 100644 index 000000000..4266f8cf7 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_floordiv.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
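Note (not part of the patch): the floor-division tests that follow build their reference as x // y plus a correction term. Torch's integer // rounds toward negative infinity, while the kernel's integer division is evidently expected to truncate toward zero; the two differ by exactly 1 when the operands have opposite signs and the division is inexact. A standalone illustration of that identity in plain PyTorch:

import torch

x = torch.tensor([7, -7, 7, -7, 6])
y = torch.tensor([2, 2, -2, -2, 3])

floor_div = x // y                                      # rounds toward -inf: [3, -4, -4, 3, 2]
correction = ((x % y != 0) & ((x ^ y) < 0)).to(x.dtype)
trunc_div = floor_div + correction                      # rounds toward 0:    [3, -3, -3, 3, 2]
assert torch.equal(trunc_div, torch.div(x, y, rounding_mode='trunc'))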
+# Only floating point clamp is supported +import pytest + +import triton +import triton.language as tl +import torch +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_floordiv(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X // Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_floordiv_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val // y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) +def test_floordiv(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + z = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + + new_shape = shape + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + y[y == 0] = 1 + ans = x // y + ans_mask = (x.to(torch.int64) % y.to(torch.int64) != 0) & (~((x ^ y) > 0)).to(ans.dtype) + ans = ans + ans_mask + + if len(shape) == 1: + triton_floordiv[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + triton_floordiv[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + triton_floordiv[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + triton_floordiv[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + triton_floordiv[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + triton_floordiv[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + triton_floordiv[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) +def test_floordiv_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + new_shape = shape + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + y[y == 0] = 1 + ans = x // y + ans_mask = (x.to(torch.int64) % y.to(torch.int64) != 0) & (~((x ^ y) > 0)).to(ans.dtype) + ans = ans + ans_mask + + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_floordiv_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'bool', + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = y.masked_fill(y == 0, 1) + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + triton_floordiv[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) diff --git a/third_party/ascend/examples/generalization_cases/test_general_fma.py b/third_party/ascend/examples/generalization_cases/test_general_fma.py new file mode 100644 index 000000000..34e7f75ec --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_fma.py @@ -0,0 +1,147 @@ +import triton +import triton.language as tl +import torch +import logging +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = 
xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + Z = tl.load(z_ptr + idx) + + ret = tl.fma(X, Y, Z) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_fma_4d_5d( + output_ptr, x_ptr, y_ptr, z_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + z_val = tl.load(z_ptr + offsets, masks) + ret = tl.fma(x_val, y_val, z_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) # math.fma do not support int dtype +def test_fma(dtype, shape): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + if (x.dtype == torch.bfloat16): + ans = x.to(torch.float32) * y.to(torch.float32) + z.to(torch.float32) + ans = ans.to(torch.bfloat16) + else: + ans = x * y + z + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_fma_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + if (x.dtype == torch.bfloat16): + ans = x.to(torch.float32) * y.to(torch.float32) + z.to(torch.float32) + ans = ans.to(torch.bfloat16) + else: + ans = x * y + z + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_fma_4d_5d[grid](output, x, y, z, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_gather.py b/third_party/ascend/examples/generalization_cases/test_general_gather.py new file mode 100644 index 000000000..0a39b0fc5 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_gather.py @@ -0,0 +1,195 @@ +import math +import numpy as np +import torch +import torch_npu +import triton +import triton.language as tl +import triton.language.extra.ascend.libdevice as libdevice +import test_common +import pytest +from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size + + +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ([2, 2], [4, 2], 0), + ([3, 3], [1, 3], 0), + ([3, 4], [4, 4], 0), + ([4, 4], [8, 4], 0), + ([4, 32], [4, 16], 1), + ([4, 64], [4, 32], 1), + ([128, 64], [128, 128], 1), +]) +def test_gather(src_shape, indices_shape, axis): + @triton.jit + def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + def 
triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + gather_kernel[(1, )](src, indices, output, axis, + src.shape[0], src.shape[1], + src.stride(0), src.stride(1), + indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), + output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) + return output + + DEV = "npu" + src = torch.randn(src_shape, device=DEV) + indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) + + dtype_size = get_dtype_size('int32') + if dtype_size * math.prod(src.shape) >= (TestUtils.ub_size / 8): + print(f"dtype:int32 shape:{src.shape} mem overflow") + return + + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +@pytest.mark.parametrize('param_list', + [ + ['float16', (11, 12, 256, 512), 48], + ['bfloat16', (11, 12, 256, 512), 48], + ['float32', (11, 12, 256, 512), 48], + ]) +def test_gather_flip(param_list): + + def torch_func(inp, idx): + return torch.gather(input=inp, dim=-1, index=idx) + + @triton.jit + def triton_kernel(dst_ptr, src_ptr, idx_ptr, + XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, + R0_BLOCK: tl.constexpr, R1_BLOCK: tl.constexpr): + pid = tl.program_id(0) + poff = pid * XBLOCK + x0_idx_base = 0 + r1_idx = tl.arange(0, R1_BLOCK) + loop0 = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for xsub_id in tl.range(loop0): + x0_idx = poff + xsub_id * XBLOCK_SUB + x0_idx_base + idx_idx = idx_ptr + x0_idx * R1_BLOCK + r1_idx + idx_blk = tl.load(idx_idx) + idx_min = tl.min(idx_blk, axis=0) + src_idx = src_ptr + x0_idx * R0_BLOCK + idx_min + r1_idx + src_blk = tl.load(src_idx) + fliped_blk = libdevice.flip(src_blk, 0) + dst_idx = dst_ptr + x0_idx * R1_BLOCK + r1_idx + tl.store(dst_idx, fliped_blk) + + def triton_func(p2c_out, p2c_att, p2c_pos, ncore): + nrows = p2c_att.shape[0] * p2c_att.shape[1] * p2c_att.shape[2] + xs = nrows // ncore + assert(xs * ncore == nrows) + xss = 1 # must be 1 + r0s = p2c_att.shape[3] + r1s = p2c_att.shape[2] + triton_kernel[ncore, 1, 1](p2c_out, p2c_att, p2c_pos, + XBLOCK=xs, XBLOCK_SUB=xss, + R0_BLOCK=r0s, R1_BLOCK=r1s) + return p2c_out + + dtype, shape, ncore = param_list + M0, M1, N0, N1 = shape + r0 = torch.arange(N0) + c0 = torch.arange(N0) + p2c_pos = r0[:, None] - c0[None, :] + N0-1 + p2c_pos = p2c_pos.broadcast_to((M0, M1, N0, N0)) + p2c_pos = p2c_pos.npu() + if (p2c_pos.dtype == torch.int64): + p2c_pos = p2c_pos.to(torch.int32) + assert(np.all(np.diff(p2c_pos.cpu()) == -1)) + p2c_att = test_common.generate_tensor(shape, dtype).npu() + p2c_out = test_common.generate_tensor(p2c_pos.shape, dtype).npu() + + p2c_ref = torch_func(p2c_att, p2c_pos) + triton_func(p2c_out, p2c_att, p2c_pos, ncore) + test_common.validate_cmp(dtype, p2c_out, p2c_ref) + + +@triton.jit +def gather_kernel_multi_d(src_ptr, idx_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, I_XB: tl.constexpr, I_YB: tl.constexpr, I_ZB: tl.constexpr, I_MB: tl.constexpr, I_NB: tl.constexpr, DIMS: tl.constexpr, AXIS: tl.constexpr): + in_offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if DIMS > 1: + in_offsets = in_offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if DIMS > 2: + in_offsets = in_offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if DIMS > 3: + in_offsets = in_offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if DIMS > 4: + 
in_offsets = in_offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + idx_offsets = tl.arange(0, I_XB) * (I_YB * I_ZB * I_MB * I_NB) + if DIMS > 1: + idx_offsets = idx_offsets[:, None] + tl.arange(0, I_YB)[None, :] * (I_ZB * I_MB * I_NB) + if DIMS > 2: + idx_offsets = idx_offsets[:, :, None] + tl.arange(0, I_ZB)[None, None, :] * (I_MB * I_NB) + if DIMS > 3: + idx_offsets = idx_offsets[:, :, :, None] + tl.arange(0, I_MB)[None, None, None, :] * I_NB + if DIMS > 4: + idx_offsets = idx_offsets[:, :, :, :, None] + tl.arange(0, I_NB)[None, None, None, None, :] + + src = tl.load(src_ptr + in_offsets) + idx = tl.load(idx_ptr + idx_offsets) + + out = tl.gather(src, idx, AXIS) + + tl.store(out_ptr + idx_offsets, out) + + +def triton_gather_multi_d(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + + s_shape = [*(src.shape)] + while len(s_shape) < 5: + s_shape.append(1) + i_shape = [*(indices.shape)] + while len(i_shape) < 5: + i_shape.append(1) + gather_kernel_multi_d[(1, )](src, indices, output, *s_shape, *i_shape, len(src.shape), axis) + return output + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ((2, 2, 4, 8), (2, 2, 4, 8), 0), + ((2, 2, 4, 1), (2, 2, 4, 1), 3), + ((2, 3, 4, 8), (2, 3, 4, 8), 1), + ((2, 3, 4, 8), (2, 3, 4, 8), 2), + ((2, 2, 2, 4, 1), (2, 2, 2, 4, 1), 4), + ((2, 2, 2, 4, 8), (2, 2, 2, 4, 8), 1), + ((2, 2, 3, 4, 8), (2, 2, 3, 4, 8), 2), + ((2, 2, 3, 4, 8), (2, 2, 3, 4, 8), 0), +]) +def test_gather_4d_5d(src_shape, indices_shape, axis): + DEV = "npu" + src = torch.randn(src_shape, device=DEV) + indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) + + ref = torch.gather(src, axis, indices) + result = triton_gather_multi_d(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +if __name__ == "__main__": + param_list = ['float16', (11, 12, 256, 512), 48] + test_gather_flip(param_list) + print("success: test_gather_flip") + test_gather([4, 64], [4, 32], 1) + print("success: test_gather") diff --git a/third_party/ascend/examples/generalization_cases/test_general_interleave.py b/third_party/ascend/examples/generalization_cases/test_general_interleave.py new file mode 100644 index 000000000..f3acf6fb0 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_interleave.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
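Note (not part of the patch): the interleave tests that follow build their reference with stack followed by reshape. A quick host-side check that this really interleaves pairs along the last axis, for illustration only:

import torch

x = torch.tensor([[0, 2, 4], [6, 8, 10]])
y = torch.tensor([[1, 3, 5], [7, 9, 11]])
# stack adds a trailing axis of length 2; reshape folds it into the last dimension
ref = torch.stack((x, y), dim=-1).reshape(2, 6)
assert torch.equal(ref, torch.tensor([[0, 1, 2, 3, 4, 5],
                                      [6, 7, 8, 9, 10, 11]]))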
+ +import triton +import triton.language as tl +import logging +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + zoffs2 = tl.program_id(2) * ZB * 2 + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + zidx2 = tl.arange(0, 2 * ZB) + zoffs2 + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.interleave(X, Y) + + oidx = xidx[:, None, None] * YNUMEL * ZNUMEL * 2 + yidx[None, :, None] * ZNUMEL * 2 + zidx2[None, None, :] + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def triton_interleave_4d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] + tmp4 = tl.arange(0, 2 * BLOCK_3)[None, None, None, :] + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + + ret = tl.interleave(x_val, y_val) + + out_offsets = pid + tmp0 * STRIDE_0 * 2 + tmp1 * STRIDE_1 * 2 + tmp2 * STRIDE_2 * 2 + tmp4 * STRIDE_3 + out_masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp4 < 2 * SHAPE_3) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@triton.jit +def triton_interleave_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] + tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] + tmp5 = tl.arange(0, 2 * BLOCK_4)[None, None, None, None, :] + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + + ret = tl.interleave(x_val, y_val) + + out_offsets = pid + tmp0 * STRIDE_0 * 2 + tmp1 * STRIDE_1 * 2 + tmp2 * STRIDE_2 * 2 + tmp3 * STRIDE_3 * 2 + tmp5 * STRIDE_4 + out_masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp5 < 2 * SHAPE_4) + 
tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_interleave(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() + new_shape = shape[:-1] + (2 * shape[-1],) + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.stack((x, y), dim=-1).reshape(new_shape) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_interleave_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape[:-1] + (2 * shape[-1],) + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.stack((x, y), dim=-1).reshape(new_shape) + + blocks = list(x.size()) + strides = list(x.stride()) + + grid = (1,) + if len(shape) == 4: + triton_interleave_4d[grid](output, x, y, *blocks, *blocks, *strides) + else: + triton_interleave_5d[grid](output, x, y, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_join.py b/third_party/ascend/examples/generalization_cases/test_general_join.py new file mode 100644 index 000000000..23f0e7455 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_join.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
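Note (not part of the patch): tl.join appends a trailing axis of length 2, which is why the join kernels below scale every input stride by 2 and add a final 0/1 offset when storing. A host-side sketch of that addressing for contiguous tensors, for illustration only:

import torch

x = torch.arange(6).reshape(2, 3)
y = x + 100
out = torch.empty(2, 3, 2, dtype=x.dtype)

flat = out.reshape(-1)                 # view over the contiguous output buffer
offs = torch.arange(x.numel())         # linear offsets into x and y
flat[2 * offs] = x.reshape(-1)         # slot 0 of each pair
flat[2 * offs + 1] = y.reshape(-1)     # slot 1 of each pair
assert torch.equal(out, torch.stack((x, y), dim=-1))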
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + ret = tl.join(X, Y) + + oidx = xidx[:, None, None, None] * YNUMEL * ZNUMEL * 2 + yidx[None, :, None, None] * ZNUMEL * 2 + \ + zidx[None, None, :, None] * 2 + tl.arange(0, 2)[None, None, None, :] + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def triton_join_4d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] + + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + + ret = tl.join(x_val, y_val) + + out_tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] + out_tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] + out_tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] + out_tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] + out_tmp4 = tl.arange(0, 2)[None, None, None, None, :] + out_offsets = pid + out_tmp0 * STRIDE_0 * 2 + out_tmp1 * STRIDE_1 * 2 + out_tmp2 * STRIDE_2 * 2 \ + + out_tmp3 * STRIDE_3 * 2 + out_tmp4 + out_masks = (out_tmp0 < SHAPE_0) & (out_tmp1 < SHAPE_1) & (out_tmp2 < SHAPE_2) \ + & (out_tmp3 < SHAPE_3) & (out_tmp4 < 2) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@triton.jit +def triton_join_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] + tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] + + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + + ret = tl.join(x_val, y_val) + + out_tmp0 = tl.arange(0, 
BLOCK_0)[:, None, None, None, None, None] + out_tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None, None] + out_tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None, None] + out_tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None, None] + out_tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :, None] + out_tmp5 = tl.arange(0, 2)[None, None, None, None, None, :] + out_offsets = pid + out_tmp0 * STRIDE_0 * 2 + out_tmp1 * STRIDE_1 * 2 + out_tmp2 * STRIDE_2 * 2 \ + + out_tmp3 * STRIDE_3 * 2 + out_tmp4 * STRIDE_4 * 2 + out_tmp5 + out_masks = (out_tmp0 < SHAPE_0) & (out_tmp1 < SHAPE_1) & (out_tmp2 < SHAPE_2) \ + & (out_tmp3 < SHAPE_3) & (out_tmp4 < SHAPE_4) & (out_tmp5 < 2) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_join(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() + new_shape = shape + (2,) + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.stack((x, y), dim=-1) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_join_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape + (2,), dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.stack((x, y), dim=-1) + + blocks = list(x.size()) + strides = list(x.stride()) + + grid = (1,) + if len(shape) == 4: + triton_join_4d[grid](output, x, y, *blocks, *blocks, *strides) + else: + triton_join_5d[grid](output, x, y, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_log.py b/third_party/ascend/examples/generalization_cases/test_general_log.py new file mode 100644 index 000000000..edc8f267b --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_log.py @@ -0,0 +1,134 @@ +import logging +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import triton.language.extra.ascend.libdevice as libdevice +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.log(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_log_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.log(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_log(dtype, shape): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + z = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.log(x).to(eval('torch.' + dtype)) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_log_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.log(x).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_log_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_log2.py b/third_party/ascend/examples/generalization_cases/test_general_log2.py new file mode 100644 index 000000000..d66178d50 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_log2.py @@ -0,0 +1,139 @@ +import logging +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import triton.language.extra.ascend.libdevice as libdevice +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.log2(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_log2_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = 
offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.log2(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_log2(dtype, shape): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + y = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + z = torch.rand(shape, dtype=eval('torch.' + dtype)).npu() + + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.log2(x).to(eval('torch.' + dtype)) + + if len(shape) == 1: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, shape[0], 1, 1, shape[0]) + elif len(shape) == 2: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + mx = max(shape[0], shape[1], shape[2]) + if mx == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif mx == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_dtypes = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_dtypes) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_log2_invalid_dtype_case(dtype): + x = test_common.generate_tensor((1,), dtype).npu() + y = test_common.generate_tensor((1,), dtype).npu() + z = test_common.generate_tensor((1,), dtype).npu() + + output = torch.randint(1, (1,), dtype=eval('torch.' + dtype)).npu() + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_log2_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.log2(x).to(eval('torch.' 
+ dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_log2_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_maximum.py b/third_party/ascend/examples/generalization_cases/test_general_maximum.py new file mode 100644 index 000000000..8d08d4409 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_maximum.py @@ -0,0 +1,138 @@ +import logging +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_maximum(x, y): + return torch.maximum(x, y) + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.maximum(X, Y) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_maximum_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = tl.maximum(x_val, y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) +def test_maximum(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + + ans = torch_maximum(x, y) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) +def test_maximum_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_maximum(x, y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_maximum_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_minimum.py b/third_party/ascend/examples/generalization_cases/test_general_minimum.py new file mode 100644 index 000000000..19228fb3b --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_minimum.py @@ -0,0 +1,139 @@ +import logging +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_minimum(x, y): + return torch.minimum(x, y) + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.minimum(X, Y) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_minimum_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, 
None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = tl.minimum(x_val, y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64', 'bool']) +def test_minimum(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_minimum(x, y) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool']) +def test_minimum_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_minimum(x, y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_minimum_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + diff --git a/third_party/ascend/examples/generalization_cases/test_general_mul.py b/third_party/ascend/examples/generalization_cases/test_general_mul.py new file mode 100644 index 000000000..3cdc873ee --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_mul.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
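The *_4d_5d kernels in this file and its neighbours share one indexing scheme: offsets and masks are built one dimension at a time, and a trailing dimension is only materialised when the product of the remaining block sizes exceeds 1, so shapes padded with 1s do not inflate the rank. A host-side torch sketch of that ladder; the helper name is illustrative and the tl.program_id base offset is omitted:

import torch

def build_offsets_and_masks(blocks, shapes, strides):
    # Dimension 0 is always materialised; each later dimension is only added
    # when the product of the remaining block sizes exceeds 1, matching the
    # `if (BLOCK_i * ...) > 1` ladder in the kernels.
    offsets = torch.arange(blocks[0]) * strides[0]
    masks = torch.arange(blocks[0]) < shapes[0]
    for d in range(1, len(blocks)):
        remaining = 1
        for b in blocks[d:]:
            remaining *= b
        if remaining > 1:
            offsets = offsets[..., None] + torch.arange(blocks[d]) * strides[d]
            masks = masks[..., None] & (torch.arange(blocks[d]) < shapes[d])
    return offsets, masks

# A 3-D problem padded to five block sizes still yields a 3-D index grid.
offs, msk = build_offsets_and_masks((2, 3, 4, 1, 1), (2, 3, 4, 1, 1), (12, 4, 1, 1, 1))
assert offs.shape == (2, 3, 4) and bool(msk.all())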
+# Only floating point clamp is supported +import pytest + +import triton +import triton.language as tl +import torch +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_mul(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X * Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_mul_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val * y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) # some shape with int8 over ub +@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_mul(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + ans = x * y + output = torch.zeros_like(ans) + + if len(shape) == 1: + triton_mul[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + triton_mul[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + triton_mul[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + triton_mul[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + triton_mul[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + triton_mul[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, 
shape[0], shape[1], shape[2]) + else: + triton_mul[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_mul_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x * y + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_mul_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_general_ravel.py b/third_party/ascend/examples/generalization_cases/test_general_ravel.py new file mode 100644 index 000000000..bd1b9ac13 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_ravel.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import logging +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.ravel(X) + + oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def triton_ravel_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, 
None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.ravel(x_val) + + pid0 = tl.program_id(0) + + flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) + out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx + out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_ravel(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() + new_shape = (x.numel(),) + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.ravel(x) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + if xnumel > 1: + grid = (XB, 1, 1) + XB = 1 + elif ynumel > 1: + grid = (1, YB, 1) + YB = 1 + else: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_ravel_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + + output = torch.randint(1, (x.numel(),), dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.ravel(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_ravel_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_reshape.py b/third_party/ascend/examples/generalization_cases/test_general_reshape.py new file mode 100644 index 000000000..3578d0195 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_reshape.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
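The test below (like the ravel and split tests nearby) picks its launch grid with a simple heuristic: keep everything in one program until a single tile would reach roughly 8 KB (the 8192 figure hard-coded in these tests, presumably a rough per-tile buffer budget rather than a documented hardware limit), then move the first non-unit axis into the grid so each program handles one slice. A plain-Python sketch of that selection, with illustrative names:

def pick_grid(shape, element_size, budget_bytes=8192):
    # Map a 1- to 3-D shape onto the (XB, YB, ZB) tile used by fn_npu_.
    if len(shape) == 1:
        xb, yb, zb = 1, 1, shape[0]
    elif len(shape) == 2:
        xb, yb, zb = 1, shape[0], shape[1]
    else:
        xb, yb, zb = shape[0], shape[1], shape[2]
    if xb * yb * zb * element_size < budget_bytes:
        return (1, 1, 1), (xb, yb, zb)
    # Tile too large for one program: split the first non-unit axis across the grid.
    if xb > 1:
        return (xb, 1, 1), (1, yb, zb)
    if yb > 1:
        return (1, yb, 1), (xb, 1, zb)
    return (1, 1, zb), (xb, yb, 1)

grid, (XB, YB, ZB) = pick_grid((32, 64), 4)  # a 32x64 float32 tensor is exactly 8 KiB
assert grid == (1, 32, 1) and (XB, YB, ZB) == (1, 1, 64)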
+ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math +import logging + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.reshape(X, (ZB * YB * XB,)) + + oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def triton_reshape_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.reshape(x_val, (SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4,)) + + pid0 = tl.program_id(0) + + flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) + out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx + out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_reshape(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() + new_shape = (x.numel(),) + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x.reshape(-1) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + if xnumel > 1: + grid = (XB, 1, 1) + XB = 1 + elif ynumel > 1: + grid = (1, YB, 1) + YB = 1 + else: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_reshape_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + + output = torch.randint(1, (x.numel(),), dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x.reshape(-1) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_reshape_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_rsqrt.py b/third_party/ascend/examples/generalization_cases/test_general_rsqrt.py new file mode 100644 index 000000000..a66714452 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_rsqrt.py @@ -0,0 +1,137 @@ +import triton +import triton.language as tl +import torch +import logging +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.rsqrt(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_rsqrt_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + 
tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.rsqrt(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_rsqrt(dtype, shape, ): + x = test_common.generate_tensor(shape, dtype).abs().npu() + y = test_common.generate_tensor(shape, dtype).abs().npu() + z = test_common.generate_tensor(shape, dtype).abs().npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.rsqrt(x) + + if len(shape) == 1: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, shape[0], 1, 1, shape[0]) + elif len(shape) == 2: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + mx = max(shape[0], shape[1], shape[2]) + if mx == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif mx == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + + test_common.validate_cmp(dtype, ans, output) + + + + + +invalid_dtypes = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_dtypes) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_rsqrt_invalid_dtype_case(dtype): + x = test_common.generate_tensor((1,), dtype).npu() + y = test_common.generate_tensor((1,), dtype).npu() + z = test_common.generate_tensor((1,), dtype).npu() + + output = torch.randint(1, (1,), dtype=eval('torch.' + dtype)).npu() + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_rsqrt_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.rsqrt(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_rsqrt_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_sigmoid.py b/third_party/ascend/examples/generalization_cases/test_general_sigmoid.py new file mode 100644 index 000000000..858982951 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_sigmoid.py @@ -0,0 +1,157 @@ +import triton +import triton.language as tl +import torch +import logging +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.sigmoid(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_sigmoid_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.sigmoid(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_sigmoid(dtype, shape, ): + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + if (x.dtype == torch.bfloat16): + ans = torch.sigmoid(x.to(torch.float32)).to(torch.bfloat16) + else: + ans = torch.sigmoid(x) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_dtypes = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_dtypes) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_sigmoid_invalid_dtype_case(dtype): + x = test_common.generate_tensor((1,), dtype).npu() + y = test_common.generate_tensor((1,), dtype).npu() + z = test_common.generate_tensor((1,), dtype).npu() + + output = torch.randint(1, (1,), dtype=eval('torch.' + dtype)).npu() + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_sigmoid_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + if (x.dtype == torch.bfloat16): + ans = torch.sigmoid(x.to(torch.float32)).to(torch.bfloat16) + else: + ans = torch.sigmoid(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_sigmoid_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_sin.py b/third_party/ascend/examples/generalization_cases/test_general_sin.py new file mode 100644 index 000000000..d18add555 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_sin.py @@ -0,0 +1,157 @@ +import triton +import triton.language as tl +import torch +import numpy as np +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.sin(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_sin_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: 
tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.sin(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +import logging + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_sin(dtype, shape, ): + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.sin(x) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_dtypes = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_dtypes) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_sin_invalid_dtype_case(dtype): + x = test_common.generate_tensor((1,), dtype).npu() + y = test_common.generate_tensor((1,), dtype).npu() + z = test_common.generate_tensor((1,), dtype).npu() + + output = torch.randint(1, (1,), dtype=eval('torch.' + dtype)).npu() + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_sin_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.sin(x) + + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_sin_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_softmax.py b/third_party/ascend/examples/generalization_cases/test_general_softmax.py new file mode 100644 index 000000000..494a03ad4 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_softmax.py @@ -0,0 +1,166 @@ +# 实际实现与官网定义不符,可能和triton submodule版本有关, 当前的submodule 不接受指定dim,都是按第0维做softmax +# arith.maximum 不支持类似 1x3 -> 3 和 1 -> 1 的reduce +import triton +import triton.language as tl +import torch +import logging +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_softmax_d0(x1): + res = torch.softmax(x1, axis=0).to(x1.dtype) + return res + + +@triton.jit +def tt_softmax_1d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + idx = tl.arange(0, XB) + x = tl.load(in_ptr + idx) + ret = tl.softmax(x) + tl.store(out_ptr + idx, ret) + + +@triton.jit +def tt_softmax_2d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + idx = xidx[:, None] * ynumel + yidx[None, :] + + a = tl.load(in_ptr + idx) + ret = tl.softmax(a) + + tl.store(out_ptr + idx, ret) + + +@triton.jit +def tt_softmax_3d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + a = tl.load(in_ptr + idx) + ret = tl.softmax(a) + + tl.store(out_ptr + idx, ret) + + +@triton.jit +def triton_softmax_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, 
None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.softmax(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_softmax(dtype, shape): + logging.log(logging.DEBUG, f"shape = {shape}", flush=True) + torch.manual_seed(0) + x = torch.rand(shape, dtype=eval('torch.' + dtype), device="npu") * 10 + grid = (1, 1, 1) + + y_cal = torch.rand(shape, dtype=eval('torch.' + dtype), device="npu") + + y_ref = torch_softmax_d0(x) + if len(shape) == 1: + tt_softmax_1d[grid](x, y_cal, x.numel(), 1, 1, x.numel(), 1, 1) + elif len(shape) == 2: + xnumel, ynumel, znumel = shape + (1,) + XB, YB, ZB = xnumel, ynumel, znumel + if x.numel() * x.element_size() > 8192: + grid = (1, ynumel, 1) + YB = 1 + tt_softmax_2d[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB) + + elif len(shape) == 3: + mx = max(shape[1], shape[2]) + if mx == shape[1]: + tt_softmax_3d[1, shape[1], 1](x, y_cal, shape[0], shape[1], shape[2], shape[0], 1, shape[2]) + else: + tt_softmax_3d[1, 1, shape[2]](x, y_cal, shape[0], shape[1], shape[2], shape[0], shape[1], 1) + + test_common.validate_cmp(dtype, y_cal, y_ref) + + +invalid_types = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_softmax_invalid_dtype_case(dtype): + x0 = test_common.generate_tensor((1,), dtype).npu() + + y_cal = torch.zeros((1,), dtype=eval('torch.' + dtype)).npu() + tt_softmax_1d[1, 1, 1](x0, y_cal, 0, 0, 0, 1, 0, 0) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_softmax_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_softmax_d0(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_softmax_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_general_split.py b/third_party/ascend/examples/generalization_cases/test_general_split.py new file mode 100644 index 000000000..b39bacfd5 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_split.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
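+# Overview: these cases exercise tl.split along a trailing dimension of size 2.
+# The host side builds an interleaved input with torch.stack((x, y), dim=-1); the
+# kernels load it through indices of the form  base_index * 2 + tl.arange(0, 2),
+# and tl.split is expected to return the two de-interleaved halves, which are then
+# compared against torch.split(xx, 1, dim=-1) reshaped to the launch geometry.
+# Illustrative sketch (comments only, not executed): for x = [1, 2], y = [7, 8],
+# xx = [[1, 7], [2, 8]] and tl.split should give back ([1, 2], [7, 8]).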
+ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, output_ptr1, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL:tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx=tl.arange(0,XB) + xoffs + yidx=tl.arange(0,YB) + yoffs + zidx=tl.arange(0,ZB) + zoffs + + idx=xidx[:,None,None,None]*YNUMEL*ZNUMEL*2+yidx[None,:,None,None]*ZNUMEL*2+ \ + zidx[None,None,:,None]*2 + tl.arange(0,2)[None,None,None,:] + + X = tl.load(x_ptr+idx) + + xx, yy = tl.split(X) + + oidx=xidx[:,None,None]*YNUMEL*ZNUMEL+yidx[None,:,None]*ZNUMEL+zidx[None,None,:] + + tl.store(output_ptr + oidx, xx) + tl.store(output_ptr1 + oidx, yy) + +import logging + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_split(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.'+dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.'+dtype)).npu() + xx = torch.stack((x, y), dim=-1) + + a, b = torch.split(xx, 1, dim=-1) + + if len(shape) == 1: + XB = 1;xnumel = 1 + YB = 1;ynumel = 1 + ZB = shape[0];znumel = shape[0] + elif len(shape) == 2: + XB = 1;xnumel = 1 + YB = shape[0]; ynumel = shape[0] + ZB = shape[1];znumel = shape[1] + else: + XB = shape[0];xnumel = shape[0] + YB = shape[1];ynumel = shape[1] + ZB = shape[2];znumel = shape[2] + + a = a.reshape(XB, YB, ZB) + b = b.reshape(XB, YB, ZB) + output = torch.randint(1, (XB,YB,ZB), dtype=eval('torch.'+dtype)).npu() + output1 = torch.randint(1, (XB,YB,ZB), dtype=eval('torch.'+dtype)).npu() + + grid = (1,1,1) + if x.numel()*x.element_size() >= 8192: + if xnumel > 1: + grid = (XB,1,1) + XB = 1 + elif ynumel > 1: + grid = (1,YB,1) + YB = 1 + else: + grid = (1,1,ZB) + ZB = 1 + + fn_npu_[grid](output, xx, output1, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, a, output) + test_common.validate_cmp(dtype, b, output1) + + +@triton.jit +def fn_npu_4_8d( + output_ptr, x_ptr, output_ptr1, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + WB: tl.constexpr, VB: tl.constexpr, UB: tl.constexpr, + TB: tl.constexpr, SB: tl.constexpr +): + + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + widx = tl.arange(0, WB) + vidx = tl.arange(0, VB) + uidx = tl.arange(0, UB) + tidx = tl.arange(0, TB) + sidx = tl.arange(0, SB) + + idx = ( + xidx[:, None, None, None, None, None, None, None, None] * + YB * ZB * WB * VB * UB * TB * SB * 2 + + yidx[None, :, None, None, None, None, None, None, None] * + ZB * WB * VB * UB * TB * SB * 2 + + zidx[None, None, :, None, None, None, None, None, None] * + WB * VB * UB * TB * SB * 2 + + widx[None, None, None, :, None, None, None, None, None] * + VB * UB * TB * SB * 2 + + vidx[None, None, None, None, :, None, None, None, None] * + UB * TB * SB * 2 + + uidx[None, None, None, None, None, :, None, None, None] * + TB * SB * 2 + + tidx[None, None, None, None, None, None, :, None, None] * + SB * 2 + + sidx[None, None, None, None, None, None, None, :, None] * 2 + + tl.arange(0, 2)[None, None, None, None, None, None, None, None, :] + ) + + X = tl.load(x_ptr + idx) + xx, yy = tl.split(X) + + oidx = ( + xidx[:, None, None, None, None, None, None, None] * + YB * ZB * WB * VB * UB * TB * SB + + yidx[None, :, None, None, None, None, None, 
None] * + ZB * WB * VB * UB * TB * SB + + zidx[None, None, :, None, None, None, None, None] * + WB * VB * UB * TB * SB + + widx[None, None, None, :, None, None, None, None] * + VB * UB * TB * SB + + vidx[None, None, None, None, :, None, None, None] * + UB * TB * SB + + uidx[None, None, None, None, None, :, None, None] * + TB * SB + + tidx[None, None, None, None, None, None, :, None] * + SB + + sidx[None, None, None, None, None, None, None, :] + ) + + tl.store(output_ptr + oidx, xx) + tl.store(output_ptr1 + oidx, yy) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape_4_8d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_split_4_8d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() + xx = torch.stack((x, y), dim=-1) + + a, b = torch.split(xx, 1, dim=-1) + + if len(shape) == 1: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, 1, 1, shape[0] + elif len(shape) == 2: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, 1, shape[0], shape[1] + elif len(shape) == 3: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, 1, shape[0], shape[1], shape[2] + elif len(shape) == 4: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, 1, shape[0], shape[1], shape[2], shape[3] + elif len(shape) == 5: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, 1, shape[0], shape[1], shape[2], shape[3], shape[4] + elif len(shape) == 6: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, 1, shape[0], shape[1], shape[2], shape[3], shape[4], shape[5] + elif len(shape) == 7: + XB, YB, ZB, WB, VB, UB, TB, SB = 1, shape[0], shape[1], shape[2], shape[3], shape[4], shape[5], shape[6] + else: + XB, YB, ZB, WB, VB, UB, TB, SB = shape + + a = a.reshape(XB, YB, ZB, WB, VB, UB, TB, SB) + b = b.reshape(XB, YB, ZB, WB, VB, UB, TB, SB) + + output = torch.randint(1, (XB, YB, ZB, WB, VB, UB, TB, SB), dtype=eval('torch.' + dtype)).npu() + output1 = torch.randint(1, (XB, YB, ZB, WB, VB, UB, TB, SB), dtype=eval('torch.' + dtype)).npu() + + grid = (1, 1, 1) + fn_npu_4_8d[grid](output, xx, output1, XB, YB, ZB, WB, VB, UB, TB, SB) + + test_common.validate_cmp(dtype, a, output) + test_common.validate_cmp(dtype, b, output1) + + diff --git a/third_party/ascend/examples/generalization_cases/test_general_sub.py b/third_party/ascend/examples/generalization_cases/test_general_sub.py new file mode 100644 index 000000000..fe2e1a186 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_sub.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
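+# Overview: element-wise subtraction (x - y) across 1d-3d shapes for a broad set
+# of dtypes, plus a 4d/5d variant that loads one masked block per program and
+# subtracts in a single pass. For 2d/3d inputs the largest axis is mapped onto the
+# launch grid and the corresponding block size is set to 1.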
+# Only floating point clamp is supported +import pytest + +import triton +import triton.language as tl +import torch +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_sub(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X - Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_sub_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val - y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_sub(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + ans = x - y + output = torch.zeros_like(ans) + + if len(shape) == 1: + triton_sub[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + triton_sub[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + triton_sub[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + triton_sub[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + triton_sub[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + triton_sub[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: 
+ triton_sub[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_sub_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x - y + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_sub_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_general_tensor_descriptor.py b/third_party/ascend/examples/generalization_cases/test_general_tensor_descriptor.py new file mode 100644 index 000000000..46139e3f1 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_tensor_descriptor.py @@ -0,0 +1,213 @@ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def triton_tensor_descriptor_2d( + out_ptr, x_ptr, + M: tl.constexpr, N: tl.constexpr, + M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, +): + in_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + block = in_desc.load([moffset, noffset]) + out_desc.store([moffset, noffset], block) + + +@triton.jit +def triton_tensor_descriptor_3d( + out_ptr, x_ptr, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + stride_m: tl.constexpr, stride_n: tl.constexpr, stride_k: tl.constexpr, + M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, K_BLOCK: tl.constexpr, +): + in_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N, K], + strides=[stride_m, stride_n, stride_k], + block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N, K], + strides=[stride_m, stride_n, stride_k], + block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + koffset = tl.program_id(2) * K_BLOCK + block = in_desc.load([moffset, noffset, koffset]) + out_desc.store([moffset, noffset, koffset], block) + + +@triton.jit +def triton_tensor_descriptor_4d( + out_ptr, x_ptr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, + SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, + BLOCK_3: tl.constexpr, +): + pid0 = tl.program_id(0) + pid1 = tl.program_id(1) + pid2 = tl.program_id(2) + idx2 = pid2 // BLOCK_3 + idx3 = pid2 % BLOCK_3 + o1 = pid0 * BLOCK_0 + o2 = pid1 * BLOCK_1 + o3 = idx2 * BLOCK_2 + o4 = idx3 * BLOCK_3 + in_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3], + strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3], + block_shape=[BLOCK_0, 
BLOCK_1, BLOCK_2, BLOCK_3], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3], + strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3], + block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3], + ) + block = in_desc.load([o1, o2, o3, o4]) + out_desc.store([o1, o2, o3, o4], block) + + +@triton.jit +def triton_tensor_descriptor_5d( + out_ptr, x_ptr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, + SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, + BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, +): + pid0 = tl.program_id(0) + pid1 = tl.program_id(1) + pid2 = tl.program_id(2) + idx3 = pid2 // (BLOCK_3 * BLOCK_4) + idx4 = (pid2 // BLOCK_4) % BLOCK_3 + idx5 = pid2 % BLOCK_4 + o1 = pid0 * BLOCK_0 + o2 = pid1 * BLOCK_1 + o3 = idx3 * BLOCK_2 + o4 = idx4 * BLOCK_3 + o5 = idx5 * BLOCK_4 + in_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4], + strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4], + block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4], + strides=[STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4], + block_shape=[BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4], + ) + block = in_desc.load([o1, o2, o3, o4, o5]) + out_desc.store([o1, o2, o3, o4, o5], block) + + +@triton.jit +def triton_tensor_descriptor_function_2d( + out_ptr, x_ptr, + M: tl.constexpr, N: tl.constexpr, + M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, +): + in_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + block = tl.load_tensor_descriptor(in_desc, [moffset, noffset]) + tl.store_tensor_descriptor(out_desc, [moffset, noffset], block) + + +temporarily_not_support_dtype = ['bool'] + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.full_shape) +def test_tensor_descriptor_load_store_nd(dtype, shape): + """测试tensor_descriptor的load和store功能""" + + if dtype in temporarily_not_support_dtype: + pytest.skip(f"{dtype} not supported") + + inp = test_common.generate_tensor(shape, dtype).npu() + out = inp.new_empty(shape) + blocks = list(inp.size()) + strides = list(inp.stride()) + grid = (1,) + dims = len(shape) + + # 如果最后一维小于16字节,则跳过 + itemsize = torch.tensor([], dtype=inp.dtype).element_size() + if blocks[-1] * itemsize < 16: + pytest.skip(f"last dimension must be at least 16 bytes, but got {blocks[-1] * itemsize} bytes") + + if dims == 2: + if inp.numel() * inp.element_size() > 8192: + triton_tensor_descriptor_2d[shape[0], 1, 1](out, inp, 1, shape[1], 1, shape[1]) + else: + triton_tensor_descriptor_2d[grid](out, inp, *shape, *blocks) + torch.testing.assert_close(inp, out) + elif dims == 3: + triton_tensor_descriptor_3d[grid](out, inp, *shape, *strides, *blocks) + torch.testing.assert_close(inp, out) + elif dims == 4: + triton_tensor_descriptor_4d[grid](out, inp, *shape, *strides, *blocks) + torch.testing.assert_close(inp, out) + elif dims == 5: + triton_tensor_descriptor_5d[grid](out, inp, *shape, 
*strides, *blocks) + torch.testing.assert_close(inp, out) + else: + pytest.skip(f"{dims}d not supported") + + +@pytest.mark.parametrize("dtype", ["float32"]) +def test_tensor_descriptor_in_function(dtype): + """测试函数式接口是否正常工作""" + + M, N = 32, 128 + inp = test_common.generate_tensor((M, N), dtype).npu() + out = inp.new_empty((M, N)) + + M_BLOCK = 8 + N_BLOCK = 32 + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + triton_tensor_descriptor_function_2d[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + torch.testing.assert_close(inp, out) diff --git a/third_party/ascend/examples/generalization_cases/test_general_view.py b/third_party/ascend/examples/generalization_cases/test_general_view.py new file mode 100644 index 000000000..ed8a212ad --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_general_view.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import logging +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.view(X, (ZB * YB * XB,)) + + oidx = tl.arange(0, XB * YB * ZB) + xoffs * YNUMEL * ZNUMEL + yoffs * ZNUMEL + zoffs + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def triton_view_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.view(x_val, (SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4,)) + + pid0 = tl.program_id(0) + + flat_idx = tl.arange(0, BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) + out_offsets = pid0 * BLOCK_0 * BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4 + flat_idx + out_masks = out_offsets < SHAPE_0 * SHAPE_1 * SHAPE_2 * SHAPE_3 * SHAPE_4 + 
tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_view(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + y = torch.full(shape, 30, dtype=eval('torch.' + dtype)).npu() + new_shape = (x.numel(),) + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x.view(new_shape) + + if len(shape) == 1: + XB = 1; + xnumel = 1 + YB = 1; + ynumel = 1 + ZB = shape[0]; + znumel = shape[0] + elif len(shape) == 2: + XB = 1; + xnumel = 1 + YB = shape[0]; + ynumel = shape[0] + ZB = shape[1]; + znumel = shape[1] + else: + XB = shape[0]; + xnumel = shape[0] + YB = shape[1]; + ynumel = shape[1] + ZB = shape[2]; + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + if xnumel > 1: + grid = (XB, 1, 1) + XB = 1 + elif ynumel > 1: + grid = (1, YB, 1) + YB = 1 + else: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_view_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.full(shape, 100, dtype=eval('torch.' + dtype)).npu() + + output = torch.randint(1, (x.numel(),), dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x.view(x.numel(), ) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_view_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_gt_op.py b/third_party/ascend/examples/generalization_cases/test_gt_op.py new file mode 100644 index 000000000..32c3f7aec --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_gt_op.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
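+# Overview: element-wise greater-than across 1d-5d shapes; the kernel result is
+# stored in the input dtype and compared against torch.gt routed through
+# torch.where so the reference keeps the same dtype. The 4d/5d kernel shared by
+# the tests in this directory builds indices by successive broadcasting:
+#   offsets[i0, ..., i4] = sum_k i_k * STRIDE_k,  masks = AND_k (i_k < SHAPE_k)
+# Illustrative sketch (comments only, not executed): with SHAPE = (2, 3) and
+# STRIDE = (3, 1), the 2d stage yields offsets = [[0, 1, 2], [3, 4, 5]], i.e.
+# plain row-major addressing of a contiguous tensor.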
+import logging +import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils + + +@triton.jit +def triton_gt_3d(in_ptr0, in_ptr1, out_ptr0, L : tl.constexpr, M : tl.constexpr, N : tl.constexpr): + lblk_idx = tl.arange(0,L) + mblk_idx = tl.arange(0,M) + nblk_idx = tl.arange(0,N) + idx = lblk_idx[:,None,None]*N*M+mblk_idx[None,:,None]*N+nblk_idx[None,None,:] + x0=tl.load(in_ptr0+idx) + x1=tl.load(in_ptr1+idx) + ret = x0 > x1 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_gt_2d(in_ptr0, in_ptr1, out_ptr0, M : tl.constexpr, N : tl.constexpr): + moffs = tl.program_id(0) * M + mblk_idx = tl.arange(0,M) + moffs + nblk_idx = tl.arange(0,N) + idx = mblk_idx[:,None]*N+nblk_idx[None,:] + x0=tl.load(in_ptr0+idx) + x1=tl.load(in_ptr1+idx) + ret = x0 > x1 + odx = mblk_idx[:,None]*N+nblk_idx[None,:] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_gt_1d(in_ptr0, in_ptr1, out_ptr0, L : tl.constexpr): + lblk_idx = tl.arange(0,L) + idx = lblk_idx[:] + x0=tl.load(in_ptr0+idx) + x1=tl.load(in_ptr1+idx) + ret = x0 > x1 + odx = lblk_idx[:] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_gt_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val > y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype',typelist) +def test_gt(sigtype, shape): + dtype = dtype_mapping[sigtype] + x0 = test_common.generate_tensor(shape = shape,dtype = sigtype).npu() + x1 = test_common.generate_tensor(shape = shape,dtype = sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref 
= torch.where(torch.gt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 3: + triton_gt_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) + if len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_gt_2d[grid](x0, x1, output, shape0, shape1) + if len(shape) == 1: + triton_gt_1d[1, 1, 1](x0, x1, output, shape[0]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_gt_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.where(torch.gt(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_gt_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_invalid_fp8.py b/third_party/ascend/examples/generalization_cases/test_invalid_fp8.py new file mode 100644 index 000000000..d469e6e47 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_invalid_fp8.py @@ -0,0 +1,37 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + + +@triton.jit +def triton_test_fp8(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp0 = tmp0.to(tl.float8e5) + tmp1 = tmp1.to(tl.float8e5) + tmp2 = tmp0 + tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ] + ) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type fp8") +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_test_fp8[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_invert.py b/third_party/ascend/examples/generalization_cases/test_invert.py new file mode 100644 index 000000000..309405331 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_invert.py @@ -0,0 +1,152 @@ +import logging + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_invert(x0, ddtype): + if 'float' in str(ddtype): + x0 = x0.to(torch.int32) + y_ref = ~x0 + y_ref = y_ref.to(ddtype) + else: + y_ref = ~x0 + return y_ref + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = ~X + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_invert_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = ~x_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ddtype = eval('torch.' 
+ dtype) + ans = torch_invert(x, ddtype) + + if len(shape) == 1: + fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_invert_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_invert(x, eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_invert_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_le_op.py b/third_party/ascend/examples/generalization_cases/test_le_op.py new file mode 100644 index 000000000..9b2a234e5 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_le_op.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
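+# Overview: element-wise less-than-or-equal, mirroring test_gt_op.py. The
+# reference is torch.where(torch.le(x, y), ones, zeros) cast back to the input
+# dtype so it can be compared against the Triton output buffer of the same dtype.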
+import logging +import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils + + +@triton.jit +def triton_le_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 <= x1 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_le_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): + moffs = tl.program_id(0) * M + mblk_idx = tl.arange(0, M) + moffs + nblk_idx = tl.arange(0, N) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 <= x1 + odx = mblk_idx[:, None] * N + nblk_idx[None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_le_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): + lblk_idx = tl.arange(0, L) + idx = lblk_idx[:] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 <= x1 + odx = lblk_idx[:] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_le_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val <= y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype', typelist) +def test_le(sigtype, shape): + dtype = dtype_mapping[sigtype] + x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=shape, 
dtype=sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = torch.where(torch.le(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 3: + triton_le_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) + if len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_le_2d[grid](x0, x1, output, shape0, shape1) + if len(shape) == 1: + triton_le_1d[1, 1, 1](x0, x1, output, shape[0]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_le_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.where(torch.le(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_le_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_load_store.py b/third_party/ascend/examples/generalization_cases/test_load_store.py new file mode 100644 index 000000000..9384b0ec7 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_load_store.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr): + idx = tl.arange(0, YB) + X = tl.load(x_ptr + idx) + tl.store(output_ptr + idx, X) + +def torch_fn_npu_1d(x): + return x + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): + pid = tl.program_id(0) + y_idx = tl.arange(0, YB)[:, None] + pid * YB + z_idx = tl.arange(0, ZB)[None, :] + idx = y_idx * ZB + z_idx + + X = tl.load(x_ptr + idx) + + tl.store(output_ptr + idx, X) + +def torch_fn_npu_2d(x): + return x + + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + y = tl.arange(0, YB)[:, None, None] + z = tl.arange(0, ZB)[None, :, None] + k = tl.arange(0, KB)[None, None, :] + + idx = y * ZB * KB + z * KB + k + + X = tl.load(x_ptr + idx) + + tl.store(output_ptr + idx, X) + +def torch_fn_npu_3d(x): + return x + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_npu(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + data_type = eval('torch.' 
+ dtype) + x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() + triton_res = torch.empty(shape, dtype=data_type).npu() + torch_res = x + if len(shape) == 1: + torch_res = torch_fn_npu_1d(x) + fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) + # uint32 转成 float32算精度,因为torch_npu不支持uint32类型张量的slice + torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) + triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) + cmp_type = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3], torch_res[:2 * shape[0] // 3]) + elif len(shape) == 2: + torch_res = torch_fn_npu_2d(x) + fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1]) + torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) + triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) + cmp_type = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3, :2 * shape[1] // 3], + torch_res[:2 * shape[0] // 3, :2 * shape[1] // 3]) + elif len(shape) == 3: + torch_res = torch_fn_npu_3d(x) + fn_npu_3d[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) + triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) + cmp_type = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_type, triton_res[:2 * shape[0] // 3, :2 * shape[1] // 3, :2 * shape[2] // 3], + torch_res[:2 * shape[0] // 3, :2 * shape[1] // 3, :2 * shape[2] // 3]) + + +# require: all data (4d and 5d) can be placed into but without ub overflow +@triton.jit +def triton_load_store_multi_d( + in_ptr0, out_ptr0, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + tmp_in = tl.load(in_ptr0 + offsets, masks) + tmp_out = tmp_in + tl.store(out_ptr0 + offsets, tmp_out, masks) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('param_list', + [ + ['float32', (8, 4, 16, 16)], + ['float16', (8, 4, 16, 16)], + ['int8', (8, 4, 16, 16)], + ['float32', (8, 8, 4, 4)], + ['float16', (8, 8, 4, 4)], + ['int8', (8, 8, 4, 4)], + ['float32', (3, 8, 2, 16, 16)], + ['float16', (3, 8, 2, 16, 16)], + ['int8', (9, 8, 8, 16, 16)], + ['float32', (11, 8, 8, 4, 4)], 
+ ['float16', (11, 8, 8, 4, 4)], + ['int8', (11, 8, 8, 4, 4)], + ] + ) +def test_load_store_4d_5d(param_list): + # 生成数据 + dtype, shape = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + y_expect = x0 + y_actual = test_common.generate_tensor(shape, dtype).npu() + # triton结果 + blocks = list(x0.size()) + shapes = list(x0.stride()) + while len(blocks) < 5: + blocks.append(1) + shapes.append(1) + triton_load_store_multi_d[(1, )](x0, y_actual, *blocks, *blocks, *shapes) + # 比较结果 + test_common.validate_cmp(dtype, y_actual, y_expect) diff --git a/third_party/ascend/examples/generalization_cases/test_log1p.py b/third_party/ascend/examples/generalization_cases/test_log1p.py new file mode 100644 index 000000000..0c7b2af93 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_log1p.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math +import triton.language.extra.ascend.libdevice as libdevice +def torch_pointwise(x0): + res = torch.log1p(x0) + return res + +@triton.jit +def triton_log1p(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp2 = libdevice.log1p(tmp0) + tl.store(out_ptr0 + (x0), tmp2, None) + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['float32', 'float16']) +def test_case(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + + numel = x0.numel() + ncore = 1 if numel <= 32 else 32 + xblock = math.ceil(numel / ncore) + xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) + + y_ref = torch_pointwise(x0) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_log1p[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_logical_and_op.py b/third_party/ascend/examples/generalization_cases/test_logical_and_op.py new file mode 100644 index 000000000..cdaf6ba51 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_logical_and_op.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
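+# Overview: tl.logical_and, restricted to bool inputs (see support_typelist
+# below). The 2d case splits the first axis across the grid once the input
+# reaches 8192 bytes, presumably to keep each per-program block within the
+# on-chip buffer; the 4d/5d case uses the shared masked-offset kernel with a
+# single program.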
+import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils, generate_tensor +import logging + + +@triton.jit +def triton_logical_and_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): + lblk_idx = tl.arange(0, L) + idx = lblk_idx + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0.logical_and(x1) + odx = lblk_idx + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_logical_and_2d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr): + loffs = tl.program_id(0) * L + lblk_idx = tl.arange(0, L) + loffs + mblk_idx = tl.arange(0, M) + idx = lblk_idx[:, None] * M + mblk_idx[None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0.logical_and(x1) + odx = lblk_idx[:, None] * M + mblk_idx[None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_logical_and_3d(in_ptr0, in_ptr1, out_ptr0, XB, YB, ZB, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + tl.program_id(0) * XB + mblk_idx = tl.arange(0, M) + tl.program_id(1) * YB + nblk_idx = tl.arange(0, N) + tl.program_id(2) * ZB + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0.logical_and(x1) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_logical_and_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val.logical_and(y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +support_typelist = ['bool', ] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype', support_typelist) +def test_logical_and(shape, sigtype): + logging.debug(f"dtype:{sigtype} shape:{shape}") + dtype = eval('torch.' 
+ sigtype) + x0 = generate_tensor(shape=shape, dtype=sigtype).npu() + x1 = generate_tensor(shape=shape, dtype=sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = torch.logical_and(x0, x1) + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 1: + triton_logical_and_1d[1, 1, 1](x0, x1, output, shape[0]) + elif len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_logical_and_2d[grid](x0, x1, output, shape0, shape1) + elif len(shape) == 3: + mx = max(shape[0], shape[1], shape[2]) + if mx == shape[0]: + triton_logical_and_3d[shape[0], 1, 1](x0, x1, output, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif mx == shape[1]: + triton_logical_and_3d[1, shape[1], 1](x0, x1, output, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + triton_logical_and_3d[1, 1, shape[2]](x0, x1, output, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['bool']) +def test_logical_and_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.logical_and(x, y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_logical_and_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_logical_or_op.py b/third_party/ascend/examples/generalization_cases/test_logical_or_op.py new file mode 100644 index 000000000..65b9550e8 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_logical_or_op.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
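+# Overview: tl.logical_or on bool inputs, structured the same way as
+# test_logical_and_op.py (grid split on the first axis for large 2d inputs,
+# shared masked-offset kernel for 4d/5d shapes).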
+import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils, generate_tensor +import logging + + +@triton.jit +def triton_logical_or_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): + lblk_idx = tl.arange(0, L) + idx = lblk_idx + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0.logical_or(x1) + odx = lblk_idx + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_logical_or_2d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr): + pid = tl.program_id(0) + lblk_idx = tl.arange(0, L) + pid * L + mblk_idx = tl.arange(0, M) + idx = lblk_idx[:, None] * M + mblk_idx[None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0.logical_or(x1) + odx = lblk_idx[:, None] * M + mblk_idx[None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_logical_or_3d(in_ptr0, in_ptr1, out_ptr0, XB, YB, ZB, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + tl.program_id(0) * XB + mblk_idx = tl.arange(0, M) + tl.program_id(1) * YB + nblk_idx = tl.arange(0, N) + tl.program_id(2) * ZB + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0.logical_or(x1) + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_logical_or_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val.logical_or(y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +support_typelist = ['bool', ] + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('sigtype', support_typelist) +def test_logical_or(shape, sigtype): + logging.debug(f"dtype:{sigtype} shape:{shape}") + dtype = eval('torch.' 
+ sigtype) + x0 = generate_tensor(shape=shape, dtype=sigtype).npu() + x1 = generate_tensor(shape=shape, dtype=sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = torch.logical_or(x0, x1) + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 1: + triton_logical_or_1d[1, 1, 1](x0, x1, output, shape[0]) + elif len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_logical_or_2d[grid](x0, x1, output, shape0, shape1) + elif len(shape) == 3: + mx = max(shape[0], shape[1], shape[2]) + if mx == shape[0]: + triton_logical_or_3d[shape[0], 1, 1](x0, x1, output, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif mx == shape[1]: + triton_logical_or_3d[1, shape[1], 1](x0, x1, output, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + triton_logical_or_3d[1, 1, shape[2]](x0, x1, output, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['bool']) +def test_logical_or_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.logical_or(x, y) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_logical_or_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_lshift_op.py b/third_party/ascend/examples/generalization_cases/test_lshift_op.py new file mode 100644 index 000000000..2c7d48847 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_lshift_op.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
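+# Generalization tests for the left-shift operator (x << 2) on the Ascend NPU
+# backend, covering 1D to 5D layouts for the signed integer dtypes; float16,
+# float32 and bfloat16 inputs are expected to fail compilation (see
+# test_invalid_types below).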
+import logging +import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils + +@triton.jit +def triton_lshift_1d(in_ptr0, out_ptr0, L : tl.constexpr): + lblk_idx = tl.arange(0,L) + idx = lblk_idx[:] + x0=tl.load(in_ptr0+idx) + ret = x0 << 2 + odx = lblk_idx[:] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_lshift_2d(in_ptr0, out_ptr0, M : tl.constexpr, N : tl.constexpr): + moffs = tl.program_id(0) * M + mblk_idx = tl.arange(0,M) + moffs + nblk_idx = tl.arange(0,N) + idx = mblk_idx[:,None]*N+nblk_idx[None,:] + x0=tl.load(in_ptr0+idx) + ret = x0 << 2 + odx = mblk_idx[:,None]*N+nblk_idx[None,:] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_lshift_3d(in_ptr0, out_ptr0, L : tl.constexpr, M : tl.constexpr, N : tl.constexpr): + loffs = tl.program_id(0) * L + lblk_idx = tl.arange(0,L) + loffs + mblk_idx = tl.arange(0,M) + nblk_idx = tl.arange(0,N) + idx = lblk_idx[:,None,None]*N*M+mblk_idx[None,:,None]*N+nblk_idx[None,None,:] + x0=tl.load(in_ptr0+idx) + ret = x0 << 2 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_lshift_4d_5d( + x_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = x_val << 2 + tl.store(output_ptr + offsets, ret, mask=masks) + + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + +typelist = ['int8','int16','int32','int64',] + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype',typelist) +def test_lshift(sigtype, shape): + dtype = dtype_mapping[sigtype] + x0 = test_common.generate_tensor(shape = shape, dtype = sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = x0 << 2 + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 3: + shape0 = shape[0] + shape1 = shape[1] + shape2 = shape[2] + if x0.numel() * x0.element_size() >= 1024: + grid = (shape0, 1, 1) + 
shape0 = 1 + else: + grid = (1, 1, 1) + triton_lshift_3d[grid](x0, output, shape0, shape1, shape2) + if len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 1024: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_lshift_2d[grid](x0, output, shape0, shape1) + if len(shape) == 1: + triton_lshift_1d[1, 1, 1](x0, output, shape[0]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) +def test_lshift_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x << 2 + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_lshift_4d_5d[grid](x, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + triton_lshift_1d[1, 1, 1](x, output, N) + diff --git a/third_party/ascend/examples/generalization_cases/test_lt_op.py b/third_party/ascend/examples/generalization_cases/test_lt_op.py new file mode 100644 index 000000000..12ba50247 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_lt_op.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
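+# Generalization tests for the element-wise less-than comparison on the Ascend NPU
+# backend: direct-indexing 1D/2D/3D kernels plus a masked 4D/5D kernel, validated
+# against torch.lt for bool, integer and floating-point dtypes.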
+import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils +import logging + + +@triton.jit +def triton_lt_3d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 < x1 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_lt_2d(in_ptr0, in_ptr1, out_ptr0, M: tl.constexpr, N: tl.constexpr): + moffs = tl.program_id(0) * M + mblk_idx = tl.arange(0, M) + moffs + nblk_idx = tl.arange(0, N) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 < x1 + odx = mblk_idx[:, None] * N + nblk_idx[None, :] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_lt_1d(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr): + lblk_idx = tl.arange(0, L) + idx = lblk_idx[:] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 < x1 + odx = lblk_idx[:] + tl.store(out_ptr0 + odx, ret) + + +@triton.jit +def triton_lt_4d_5d( + x_ptr, y_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val < y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +typelist = ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'] + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype', typelist) +def test_lt(sigtype, shape): + dtype = dtype_mapping[sigtype] + x0 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=shape, dtype=sigtype).npu() 
+ # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = torch.where(torch.lt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + sigtype)) + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 3: + triton_lt_3d[1, 1, 1](x0, x1, output, shape[0], shape[1], shape[2]) + if len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_lt_2d[grid](x0, x1, output, shape0, shape1) + if len(shape) == 1: + triton_lt_1d[1, 1, 1](x0, x1, output, shape[0]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32']) +def test_lt_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.where(torch.lt(x, y), torch.ones_like(x), torch.zeros_like(x)).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_lt_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_make_blkptr_matmul.py b/third_party/ascend/examples/generalization_cases/test_make_blkptr_matmul.py new file mode 100644 index 000000000..6860ed614 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_make_blkptr_matmul.py @@ -0,0 +1,64 @@ +import logging +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, avoid_not_support, get_dtype_size + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + acc_dtype: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + matxa_ptr_in = tl.make_block_ptr(a_ptr, + (M, K), + (K, 1), + (0, 0), + (M, K), + order=(1, 0)) + matxb_ptr_in = tl.make_block_ptr(b_ptr, + (K, N), + (N, 1), + (0, 0), + (K, N), + order=(1, 0)) + matxc_ptr_in = tl.make_block_ptr(c_ptr, + (M, N), + (N, 1), + (0, 0), + (M, N), + order=(1, 0)) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + a = tl.load(matxa_ptr_in) + b = tl.load(matxb_ptr_in) + accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(matxc_ptr_in, c) + + +@avoid_not_support('matmul') +@pytest.mark.parametrize('shape', [(16, 32)]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_matmul(shape, dtype): + M, N, K = shape[0], shape[0], shape[1] + + BLOCK_M, BLOCK_N, BLOCK_K = M, N, K + a = test_common.generate_tensor((M, K), dtype) + b = test_common.generate_tensor((K, N), dtype) + + triton_res = torch.zeros((M, N), dtype=eval('torch.' 
+ dtype)).npu() + accumulator_type = tl.float32 + + matmul_kernel[1, ](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, + BLOCK_M, BLOCK_N, BLOCK_K, enable_nd2nz_on_vector=False) + + print("PASSED") diff --git a/third_party/ascend/examples/generalization_cases/test_make_block_ptr.py b/third_party/ascend/examples/generalization_cases/test_make_block_ptr.py new file mode 100644 index 000000000..d71370fe8 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_make_block_ptr.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils + + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB,), + strides=(1,), + offsets=(0,), + block_shape=(XB,), + order=(0,), + ) + X = tl.load(block_ptr_in) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB,), + strides=(1,), + offsets=(0,), + block_shape=(XB,), + order=(0,), + ) + tl.store(block_ptr_out, X) + + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffset = tl.program_id(0) + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB), + strides=(YB, 1), + offsets=(xoffset, 0), + block_shape=(XB, YB), + order=(1, 0), + ) + X = tl.load(block_ptr_in) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB), + strides=(YB, 1), + offsets=(xoffset, 0), + block_shape=(XB, YB), + order=(1, 0), + ) + tl.store(block_ptr_out, X) + + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, ZB), + order=(2, 1, 0), + ) + X = tl.load(block_ptr_in) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, ZB), + order=(2, 1, 0), + ) + tl.store(block_ptr_out, X) + + +@triton.jit +def triton_make_block_ptr_4d(output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, ): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), + offsets=(0, 0, 0, 0), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), + order=(3, 2, 1, 0), + ) + x = tl.load(block_ptr_in) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3), + offsets=(0, 0, 0, 0), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3), + order=(3, 2, 1, 0), + ) + tl.store(block_ptr_out, x) + + +@triton.jit +def triton_make_block_ptr_5d(output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: 
tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr, ): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), + offsets=(0, 0, 0, 0, 0), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), + order=(4, 3, 2, 1, 0), + ) + x = tl.load(block_ptr_in) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(SHAPE_0, SHAPE_1, SHAPE_2, SHAPE_3, SHAPE_4), + strides=(STRIDE_0, STRIDE_1, STRIDE_2, STRIDE_3, STRIDE_4), + offsets=(0, 0, 0, 0, 0), + block_shape=(BLOCK_0, BLOCK_1, BLOCK_2, BLOCK_3, BLOCK_4), + order=(4, 3, 2, 1, 0), + ) + tl.store(block_ptr_out, x) + + +temporarily_not_support_dtype = ['bool'] + + +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('shape', TestUtils.full_shape) +def test_npu(dtype, shape): + if dtype in temporarily_not_support_dtype: + return + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + + a = x + blocks = list(x.size()) + strides = list(x.stride()) + grid = (1,) + if len(shape) == 5: + triton_make_block_ptr_5d[grid](output, x, *blocks, *blocks, *strides) + elif len(shape) == 4: + triton_make_block_ptr_4d[grid](output, x, *blocks, *blocks, *strides) + elif len(shape) == 3: + fn_npu_3d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=shape[2]) + elif len(shape) == 2: + if x.numel() * x.element_size() > 8192: + fn_npu_2d[shape[0], 1, 1](output, x, y, z, output1, XB=1, YB=shape[1], ZB=1) + else: + fn_npu_2d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) + else: + fn_npu_1d[1, 1, 1](output, x, y, z, output1, XB=shape[0], YB=1, ZB=1) + torch.testing.assert_close(output, a) diff --git a/third_party/ascend/examples/generalization_cases/test_matmul.py b/third_party/ascend/examples/generalization_cases/test_matmul.py new file mode 100644 index 000000000..55691251a --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_matmul.py @@ -0,0 +1,164 @@ +import logging +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +import acc_util +import test_common +from test_common import TestUtils, avoid_not_support, get_dtype_size + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + acc_dtype: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * 
BLOCK_K, other=0.0) + accumulator = tl.dot(a, b, accumulator, out_dtype=acc_dtype) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = accumulator.to(c_ptr.dtype.element_ty) + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +@avoid_not_support('matmul') +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_matmul(shape, dtype): + M, N, K = shape[0], shape[0], shape[1] + # 32byte/Dtype_bytes + kalign = 32 // get_dtype_size(dtype) + BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32) + a = test_common.generate_tensor((M, K), dtype) + b = test_common.generate_tensor((K, N), dtype) + + if dtype == "int8": + triton_res = torch.zeros((M, N), dtype=torch.int32).npu() + accumulator_type = tl.int32 + else: + triton_res = torch.zeros((M, N), dtype=eval('torch.' + dtype)).npu() + accumulator_type = tl.float32 + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + + matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type, + a.stride(0), a.stride(1), b.stride(0), b.stride(1), + triton_res.stride(0), triton_res.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K) + + a_gold = a.to(torch.float32) + b_gold = b.to(torch.float32) + cpu_res = torch.mm(a_gold, b_gold) + + if dtype == "int8": + # torch_npu do not support int8 matmul + a_npu = a.npu().to(torch.float32) + b_npu = b.npu().to(torch.float32) + torch_res = torch.mm(a_npu, b_npu) + triton_res = triton_res.to(torch.float32) + else: + a_npu = a.npu() + b_npu = b.npu() + torch_res = torch.mm(a_npu, b_npu) + + try: + print("starting compare of cpu vs triton:") + acc_util.assert_close(cpu_res, triton_res) + except Exception as e: + print(e) + print("starting compare of cpu vs triton vs torch_npu:") + acc_util.benchmark_compare_close(cpu_res, triton_res, torch_res) + print("PASSED") + + +@avoid_not_support('matmul') +@pytest.mark.parametrize('batch', TestUtils.batch) +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_batch_matmul(shape, dtype, batch): + M, N, K = shape[0], shape[0], shape[1] + # 32byte/Dtype_bytes + kalign = 32 // get_dtype_size(dtype) + BLOCK_M, BLOCK_N, BLOCK_K = min(max(M, 16), 32), min(max(N, 16), 32), min(max(K, kalign), 32) + + aa = test_common.generate_tensor((batch, M, K), dtype) + bb = test_common.generate_tensor((batch, K, N), dtype) + + if dtype == "int8": + final_triton_res = torch.zeros((batch, M, N), dtype=torch.int32).npu() + accumulator_type = tl.int32 + else: + final_triton_res = torch.zeros((batch, M, N), dtype=eval('torch.' + dtype)).npu() + accumulator_type = tl.float32 + + for i in range(0, batch): + if dtype == "int8": + triton_res = torch.zeros((M, N), dtype=torch.int32).npu() + else: + triton_res = torch.zeros((M, N), dtype=eval('torch.' 
+ dtype)).npu()
+        grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
+        a = aa[i]
+        b = bb[i]
+        matmul_kernel[grid](a.npu(), b.npu(), triton_res, M, N, K, accumulator_type,
+                            a.stride(0), a.stride(1), b.stride(0), b.stride(1),
+                            triton_res.stride(0), triton_res.stride(1),
+                            BLOCK_M, BLOCK_N, BLOCK_K)
+        final_triton_res[i] = triton_res
+
+    a_gold = aa.to(torch.float32)
+    b_gold = bb.to(torch.float32)
+    cpu_res = torch.bmm(a_gold, b_gold)
+
+    if dtype == "int8":
+        # torch_npu does not support int8 matmul, so build the reference in float32
+        a_npu = aa.npu().to(torch.float32)
+        b_npu = bb.npu().to(torch.float32)
+        torch_res = torch.bmm(a_npu, b_npu)
+        final_triton_res = final_triton_res.to(torch.float32)
+    else:
+        a_npu = aa.npu()
+        b_npu = bb.npu()
+        torch_res = torch.bmm(a_npu, b_npu)
+
+    try:
+        print("starting compare of cpu vs triton:")
+        acc_util.assert_close(cpu_res, final_triton_res)
+    except Exception as e:
+        print(e)
+        print("starting compare of cpu vs triton vs torch_npu:")
+        acc_util.benchmark_compare_close(cpu_res, final_triton_res, torch_res)
+    print("PASSED")
+
+
+if __name__ == "__main__":
+    test_matmul((16, 32), 'float32')
+    test_matmul((16, 32), 'int8')
+    test_batch_matmul((16, 32), 'float32', 2)
+    test_batch_matmul((16, 32), 'int8', 2)
diff --git a/third_party/ascend/examples/generalization_cases/test_max.py b/third_party/ascend/examples/generalization_cases/test_max.py
new file mode 100644
index 000000000..db3f71844
--- /dev/null
+++ b/third_party/ascend/examples/generalization_cases/test_max.py
@@ -0,0 +1,274 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
+import math
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import pytest
+import test_common
+from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size
+
+# <<<<<<< test_max_1d
+def torch_max(x0, dim, keepdim):
+    inp = x0 if x0.device == "cpu" else x0.cpu()
+    return torch.max(inp, dim=dim, keepdim=keepdim)[0].npu()
+
+@triton.jit
+def triton_max_1d(in_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) + tl.arange(0, XBLOCK)
+    tmp0 = tl.load(in_ptr0 + xoffset, None)
+    tmp4 = tl.max(tmp0, 0)
+    tl.store(out_ptr1, tmp4, None)
+
+@pytest.mark.parametrize('shape', TestUtils.test_shape1d)
+@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
+def test_max_1d(dtype, shape):
+    if check_ub_mem_overflow(dtype, shape):
+        return
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    triton_res = torch.empty(1, dtype=eval("torch." 
+ dtype)).npu() + numel = shape[0] + triton_max_1d[1,1,1](x0, triton_res, numel, numel) + torch_res = torch_max(x0, dim=0, keepdim=True) + test_common.validate_cmp(dtype, triton_res, torch_res) + +# >>>>>>> test_max_1d + +# <<<<<<< test_max_2d +@triton.jit +def triton_max_2d(in_ptr0, out_ptr0, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None] * N + nblk_idx[None,:] + x = tl.load(in_ptr0 + idx, mask = mask, other = -float('inf')) + tmp4 = tl.max(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0,N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0,M), tmp4, None) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dim', [0, 1]) +def test_max_2d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype == 'int8' or dtype == 'bool': + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[1-dim], ], dtype=eval("torch." + dtype)).npu() + triton_max_2d[1,1,1](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_max(x0, dim=dim, keepdim=False) + test_common.validate_cmp(dtype, triton_res, torch_res) + +# >>>>>>> test_max_2d + +# <<<<<<< test_max_3d +def torch_max_3d(x0, no_reduce_dim): + inp = x0 if x0.device == "cpu" else x0.cpu() + if no_reduce_dim == 0: + return torch.max(torch.max(inp, 1)[0], 1)[0].npu() + elif no_reduce_dim == 1: + return torch.max(torch.max(inp, 0)[0], 1)[0].npu() + elif no_reduce_dim == 2: + return torch.max(torch.max(inp, 0)[0], 0)[0].npu() + else: + assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" + +@triton.jit +def triton_max_3d_0_1(in_ptr, out_ptr, + xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr, + XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr): + xidx = tl.arange(0,XB) + yidx = tl.arange(0,YB) + zidx = tl.arange(0,ZB) + idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:] + x = tl.load(in_ptr + idx) + tmp = tl.max(x, 0) + ret = tl.max(tmp, 0) + oidx = zidx + tl.store(out_ptr + oidx, ret) + +@triton.jit +def triton_max_3d_0_2(in_ptr, out_ptr, + xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr, + XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr): + xidx = tl.arange(0,XB) + yidx = tl.arange(0,YB) + zidx = tl.arange(0,ZB) + idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:] + x = tl.load(in_ptr + idx) + tmp = tl.max(x, 0) + ret = tl.max(tmp, 1) + oidx = yidx + tl.store(out_ptr + oidx, ret) + +@triton.jit +def triton_max_3d_1_2(in_ptr, out_ptr, + xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr, + XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr): + xidx = tl.arange(0,XB) + yidx = tl.arange(0,YB) + zidx = tl.arange(0,ZB) + idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:] + x = tl.load(in_ptr + idx) + tmp = tl.max(x, 1) + ret = tl.max(tmp, 1) + oidx = xidx + tl.store(out_ptr + oidx, ret) + +def triton_max_3d(in_ptr, out_ptr, xnumel, ynumel, 
znumel, XB, YB, ZB, no_reduce_dim): + if no_reduce_dim == 0: + triton_max_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 1: + triton_max_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 2: + triton_max_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) +def test_max_3d(dtype, shape, no_reduce_dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[no_reduce_dim], ], dtype=eval("torch."+dtype)).npu() + triton_max_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) + torch_res = torch_max_3d(x0, no_reduce_dim) + test_common.validate_cmp(dtype, triton_res, torch_res) + +# >>>>>>> test_max_3d + + +# <<<<<<< test_max_4d +def torch_max_4d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.max(x0, dim=dim)[0] + + +@triton.jit +def max_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB // XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB // MB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_max_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + + idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[None, None, :, None] * MB + midx[None, None, None, :] + + x = tl.load(in_ptr + idx) + + max_4d(out_ptr, x, XB, YB, ZB, MB, DIM) + + +def triton_max_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): + triton_max_kernel_4d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 4, 8) +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dim', [0]) +def test_max_4d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_max_4d(x0, dim) + triton_res = torch.empty_like(torch_res).npu() + triton_max_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) + + test_common.validate_cmp(dtype, triton_res, torch_res) +# >>>>>>> test_max_4d + + +# <<<<<<< test_max_5d +def torch_max_5d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.max(x0, dim=dim)[0] + + +@triton.jit +def max_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // 
XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 3: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.max(x, DIM), XB * YB * ZB * MB * NB // NB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_max_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + + idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] + + x = tl.load(in_ptr + idx) + + max_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) + + +def triton_max_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): + triton_max_kernel_5d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 2, 4, 8) +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dim', [0]) +def test_max_5d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_max_5d(x0, dim) + triton_res = torch.empty_like(torch_res).npu() + triton_max_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) + + test_common.validate_cmp(dtype, triton_res, torch_res) +# >>>>>>> test_max_5d diff --git a/third_party/ascend/examples/generalization_cases/test_min.py b/third_party/ascend/examples/generalization_cases/test_min.py new file mode 100644 index 000000000..226ab1de5 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_min.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +import math +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size + +# <<<<<<< test_min_1d +def torch_min(x0, dim, keepdim): + inp = x0 if x0.device == "cpu" else x0.cpu() + return torch.min(inp, dim=dim, keepdim=keepdim)[0].npu() + +@triton.jit +def triton_min_1d(in_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.min(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +def test_min_1d(dtype, shape): + if check_ub_mem_overflow(dtype, shape): + return + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=eval("torch." 
+ dtype)).npu() + numel = shape[0] + triton_min_1d[1,1,1](x0, triton_res, numel, numel) + torch_res = torch_min(x0, dim=0, keepdim=True) + test_common.validate_cmp(dtype, triton_res, torch_res) + +# >>>>>>> test_min_1d + +# <<<<<<< test_min_2d +@triton.jit +def triton_min_2d(in_ptr0, out_ptr0, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None] * N + nblk_idx[None,:] + x = tl.load(in_ptr0 + idx, mask = mask, other = -float('inf')) + tmp4 = tl.min(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0,N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0,M), tmp4, None) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dim', [0, 1]) +def test_min_2d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype == 'int8' or dtype == 'bool': + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[1-dim], ], dtype=eval("torch." + dtype)).npu() + triton_min_2d[1,1,1](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_min(x0, dim=dim, keepdim=False) + test_common.validate_cmp(dtype, triton_res, torch_res) + +# >>>>>>> test_min_2d + +# <<<<<<< test_min_3d +def torch_min_3d(x0, no_reduce_dim): + inp = x0 if x0.device == "cpu" else x0.cpu() + if no_reduce_dim == 0: + return torch.min(torch.min(inp, 1)[0], 1)[0].npu() + elif no_reduce_dim == 1: + return torch.min(torch.min(inp, 0)[0], 1)[0].npu() + elif no_reduce_dim == 2: + return torch.min(torch.min(inp, 0)[0], 0)[0].npu() + else: + assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" + +@triton.jit +def triton_min_3d_0_1(in_ptr, out_ptr, + xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr, + XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr): + xidx = tl.arange(0,XB) + yidx = tl.arange(0,YB) + zidx = tl.arange(0,ZB) + idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:] + x = tl.load(in_ptr + idx) + tmp = tl.min(x, 0) + ret = tl.min(tmp, 0) + oidx = zidx + tl.store(out_ptr + oidx, ret) + +@triton.jit +def triton_min_3d_0_2(in_ptr, out_ptr, + xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr, + XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr): + xidx = tl.arange(0,XB) + yidx = tl.arange(0,YB) + zidx = tl.arange(0,ZB) + idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:] + x = tl.load(in_ptr + idx) + tmp = tl.min(x, 0) + ret = tl.min(tmp, 1) + oidx = yidx + tl.store(out_ptr + oidx, ret) + +@triton.jit +def triton_min_3d_1_2(in_ptr, out_ptr, + xnumel:tl.constexpr, ynumel:tl.constexpr, znumel:tl.constexpr, + XB:tl.constexpr, YB:tl.constexpr, ZB:tl.constexpr): + xidx = tl.arange(0,XB) + yidx = tl.arange(0,YB) + zidx = tl.arange(0,ZB) + idx = xidx[:,None,None]*ynumel*znumel + yidx[None,:,None]*znumel + zidx[None,None,:] + x = tl.load(in_ptr + idx) + tmp = tl.min(x, 1) + ret = tl.min(tmp, 1) + oidx = xidx + tl.store(out_ptr + oidx, ret) + +def triton_min_3d(in_ptr, out_ptr, xnumel, ynumel, 
znumel, XB, YB, ZB, no_reduce_dim): + if no_reduce_dim == 0: + triton_min_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 1: + triton_min_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 2: + triton_min_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) +def test_min_3d(dtype, shape, no_reduce_dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[no_reduce_dim], ], dtype=eval("torch." + dtype)).npu() + triton_min_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) + torch_res = torch_min_3d(x0, no_reduce_dim) + test_common.validate_cmp(dtype, triton_res, torch_res) + +# >>>>>>> test_min_3d + + +# <<<<<<< test_min_4d +def torch_min_4d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.min(x0, dim=dim)[0] + + +@triton.jit +def min_4d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB // XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB // ZB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB // MB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_min_kernel_4d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + + idx = xidx[:, None, None, None] * YB * ZB * MB + yidx[None, :, None, None] * ZB * MB + zidx[None, None, :, None] * MB + midx[None, None, None, :] + + x = tl.load(in_ptr + idx) + + min_4d(out_ptr, x, XB, YB, ZB, MB, DIM) + + +def triton_min_4d(in_ptr, out_ptr, XB, YB, ZB, MB, dim): + triton_min_kernel_4d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 4, 8) +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dim', [0]) +def test_min_4d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_min_4d(x0, dim) + triton_res = torch.empty_like(torch_res).npu() + triton_min_4d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], dim) + + test_common.validate_cmp(dtype, triton_res, torch_res) +# >>>>>>> test_min_4d + + +# <<<<<<< test_min_5d +def torch_min_5d(x0, dim): + x0 = x0 if x0.device == "cpu" else x0.cpu() + if x0.dtype in (torch.int8, torch.int16, torch.int32): + x0 = x0.to(torch.int64) + return torch.min(x0, dim=dim)[0] + + +@triton.jit +def min_5d(out_ptr, x, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + if DIM == 0: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // XB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB 
// XB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 1: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // YB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // YB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 2: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // ZB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // ZB) + tl.store(out_ptr + o_idx, ret) + elif DIM == 3: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // MB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // MB) + tl.store(out_ptr + o_idx, ret) + else: + ret = tl.reshape(tl.min(x, DIM), XB * YB * ZB * MB * NB // NB) + o_idx = tl.arange(0, XB * YB * ZB * MB * NB // NB) + tl.store(out_ptr + o_idx, ret) + + +@triton.jit +def triton_min_kernel_5d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIM: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + midx = tl.arange(0, MB) + nidx = tl.arange(0, NB) + + idx = xidx[:, None, None, None, None] * YB * ZB * MB * NB + yidx[None, :, None, None, None] * ZB * MB * NB + zidx[None, None, :, None, None] * MB * NB + midx[None, None, None, :, None] * NB + nidx[None, None, None, None, :] + + x = tl.load(in_ptr + idx) + + min_5d(out_ptr, x, XB, YB, ZB, MB, NB, DIM) + + +def triton_min_5d(in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim): + triton_min_kernel_5d[(1,)](in_ptr, out_ptr, XB, YB, ZB, MB, NB, dim) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 2, 2, 4, 8) +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dim', [0]) +def test_min_5d(dtype, shape, dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_min_5d(x0, dim) + triton_res = torch.empty_like(torch_res).npu() + triton_min_5d(x0, triton_res, shape[0], shape[1], shape[2], shape[3], shape[4], dim) + + test_common.validate_cmp(dtype, triton_res, torch_res) +# >>>>>>> test_min_5d diff --git a/third_party/ascend/examples/generalization_cases/test_mod.py b/third_party/ascend/examples/generalization_cases/test_mod.py new file mode 100644 index 000000000..d03d71d86 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_mod.py @@ -0,0 +1,207 @@ +import logging + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_pointwise(x, y): + res = x % y + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X % Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_mod_4d( + output_ptr, x_ptr, y_ptr, + BLOCK_SIZE: tl.constexpr, SUB_BLOCK: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, +): + pid = tl.program_id(0) + for loop in range(0, tl.cdiv(BLOCK_SIZE, SUB_BLOCK)): + base_idx = tl.arange(0, SUB_BLOCK) + pid_tensor = tl.full((SUB_BLOCK,), pid * 
BLOCK_SIZE + loop * SUB_BLOCK, dtype=tl.int32) + tmp0 = (pid_tensor + base_idx)[:, None, None, None] + tmp1 = tl.arange(0, SHAPE_1)[None, :, None, None] + tmp2 = tl.arange(0, SHAPE_2)[None, None, :, None] + tmp3 = tl.arange(0, SHAPE_3)[None, None, None, :] + offsets = tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + masks = tmp0 < SHAPE_0 + x = tl.load(x_ptr + offsets, mask=masks) + y = tl.load(y_ptr + offsets, mask=masks) + ret = x % y + tl.store(output_ptr + offsets, ret, mask=masks) + + +@triton.jit +def triton_mod_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_SIZE: tl.constexpr, SUB_BLOCK: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr, +): + pid = tl.program_id(0) + for loop in range(0, tl.cdiv(BLOCK_SIZE, SUB_BLOCK)): + base_idx = tl.arange(0, SUB_BLOCK) + pid_tensor = tl.full((SUB_BLOCK,), pid * BLOCK_SIZE + loop * SUB_BLOCK, dtype=tl.int32) + tmp0 = (pid_tensor + base_idx)[:, None, None, None, None] + tmp1 = tl.arange(0, SHAPE_1)[None, :, None, None, None] + tmp2 = tl.arange(0, SHAPE_2)[None, None, :, None, None] + tmp3 = tl.arange(0, SHAPE_3)[None, None, None, :, None] + tmp4 = tl.arange(0, SHAPE_4)[None, None, None, None, :] + offsets = tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 + masks = tmp0 < SHAPE_0 + x = tl.load(x_ptr + offsets, mask=masks) + y = tl.load(y_ptr + offsets, mask=masks) + ret = x % y + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) +def test_case2(dtype, shape): + if dtype in ['int8', 'int16', 'int32', 'int64']: + x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + z = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + else: + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + + x[x <= 0] = 1 + y[y <= 0] = 1 + z[z <= 0] = 1 + + ans = torch_pointwise(x.cpu(), y.cpu()) + ans = ans.npu() + output = torch.zeros_like(ans) + + if len(shape) == 1: + fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_types = [ + 'bool', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), 
dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) + + +@pytest.mark.parametrize('shape', + TestUtils.test_shape4d + [(25, 2, 3, 31), (2, 2, 39, 23), (17, 27, 3, 3), (3, 2, 27, 37)]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_mod_4d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + if dtype in ['int8', 'int16', 'int32', 'int64']: + x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + else: + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + x[x <= 0] = 1 + y[y <= 0] = 1 + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_pointwise(x.cpu(), y.cpu()) + ans = ans.npu() + + n = x.numel() + block_size = min(triton.next_power_of_2(n), 64) + sub_block_size = 1 + grid = (triton.cdiv(n, block_size),) + print(" ") + print(f"=== loops: {triton.cdiv(block_size, sub_block_size)}") + print(f"=== grid : {grid}") + triton_mod_4d[grid](output, x, y, block_size, sub_block_size, *list(shape), *list(x.stride())) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape5d + [(32, 5, 3, 1, 8)]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_mod_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + if dtype in ['int8', 'int16', 'int32', 'int64']: + x = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + y = test_common.generate_tensor_int_withSigns(shape, dtype).npu() + else: + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + x[x <= 0] = 1 + y[y <= 0] = 1 + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu()
+    logging.log(logging.DEBUG, f"output.dtype={output.dtype}")
+
+    ans = torch_pointwise(x.cpu(), y.cpu())
+    ans = ans.npu()
+
+    n = x.numel()
+    block_size = min(triton.next_power_of_2(n), 32)
+    sub_block_size = 1
+    grid = (triton.cdiv(n, block_size),)
+    print(" ")
+    print(f"=== loops: {triton.cdiv(block_size, sub_block_size)}")
+    print(f"=== grid : {grid}")
+    triton_mod_5d[grid](output, x, y, block_size, sub_block_size, *list(shape), *list(x.stride()))
+
+    test_common.validate_cmp(dtype, ans, output)
diff --git a/third_party/ascend/examples/generalization_cases/test_ne.py b/third_party/ascend/examples/generalization_cases/test_ne.py
new file mode 100644
index 000000000..a736b01af
--- /dev/null
+++ b/third_party/ascend/examples/generalization_cases/test_ne.py
@@ -0,0 +1,113 @@
+import pytest
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import test_common
+from test_common import TestUtils
+import math
+import logging
+
+
+def torch_ne(x0, x1):
+    if x0.dtype != torch.uint32:
+        return x0 != x1
+    else:
+        return x0.to(torch.float32) != x1.to(torch.float32)
+
+
+@triton.jit
+def triton_ne(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
+    offset = tl.program_id(0) * XBLOCK
+    base1 = tl.arange(0, XBLOCK_SUB)
+    loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
+    for loop1 in range(loops1):
+        x_index = offset + (loop1 * XBLOCK_SUB) + base1
+        tmp0 = tl.load(in_ptr0 + x_index, mask=x_index < N)
+        tmp1 = tl.load(in_ptr1 + x_index, mask=x_index < N)
+        tmp2 = tmp0 != tmp1
+        tl.store(out_ptr0 + x_index, tmp2, mask=x_index < N)
+
+
+@triton.jit
+def triton_ne_4d_5d(
+        x_ptr, y_ptr, output_ptr,
+        BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr,
+        BLOCK_4: tl.constexpr,
+        SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr,
+        SHAPE_4: tl.constexpr,
+        STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr,
+        STRIDE_4: tl.constexpr
+):
+    offsets = tl.program_id(0)
+
+    offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0
+    masks = tl.arange(0, BLOCK_0) < SHAPE_0
+    if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
+        offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1
+        masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1)
+    if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1:
+        offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2
+        masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2)
+    if (BLOCK_3 * BLOCK_4) > 1:
+        offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3
+        masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3)
+    if BLOCK_4 > 1:
+        offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4
+        masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4)
+
+    x_val = tl.load(x_ptr + offsets, masks)
+    y_val = tl.load(y_ptr + offsets, masks)
+    ret = x_val != y_val
+    tl.store(output_ptr + offsets, ret, mask=masks)
+
+
+@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d)
+@pytest.mark.parametrize('dtype', ['bool', 'int8', 'int16', 'int32', 'int64', 'float16', 'bfloat16', 'float32'])
+def test_ne(shape, dtype):
+    logging.debug(f'dtype:{dtype} shape:{shape}')
+    # generate the input data
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    x1 = test_common.generate_tensor(shape, dtype).npu()
+
+    numel = 
x0.numel() + ncore = 1 if numel <= 32 else 32 + xblock = math.ceil(numel / ncore) + xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) + + # torch结果 + torch_res = torch_ne(x0, x1).to(eval('torch.' + dtype)) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + N = triton_res.numel() + triton_ne[ncore, 1, 1](x0, x1, triton_res, N, xblock, xblock_sub) + # 比较结果 + torch_res = torch_res if dtype != 'uint32' else torch_res.to(torch.float32) + triton_res = triton_res if dtype != 'uint32' else triton_res.to(torch.float32) + cmp_dtype = dtype if dtype != 'uint32' else 'float32' + test_common.validate_cmp(cmp_dtype, triton_res, torch_res) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_ne_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_ne(x, y).to(eval('torch.' + dtype)) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_ne_4d_5d[grid](x, y, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_neg.py b/third_party/ascend/examples/generalization_cases/test_neg.py new file mode 100644 index 000000000..87c3ac531 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_neg.py @@ -0,0 +1,144 @@ +import logging + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_pointwise(x): + res = -x + return res + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = -X + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_neg_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 
* BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = -x_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) +def test_case2(dtype, shape): + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_pointwise(x.cpu()) + ans = ans.npu() + + if len(shape) == 1: + fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) +def test_neg_4d_5d(shape, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_pointwise(x.cpu()) + ans = ans.npu() + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_neg_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_types = [ + 'bool', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_not.py b/third_party/ascend/examples/generalization_cases/test_not.py new file mode 100644 index 000000000..462612e1a --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_not.py @@ -0,0 +1,160 @@ +import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math + + +def torch_not(x0): + res = torch.bitwise_not(x0) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = not(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_not_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = not(x_val) + tl.store(output_ptr + offsets, ret, 
mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_not(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) + + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_not_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_not(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_not_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_or.py b/third_party/ascend/examples/generalization_cases/test_or.py new file mode 100644 index 000000000..5f2e1dcd3 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_or.py @@ -0,0 +1,145 @@ +import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math + + +def torch_or(x0, x1): + return x0 | x1 + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X | Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_or_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val | y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + ans = torch_or(x, y) + output = torch.zeros_like(ans) + + if len(shape) == 1: + fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + 
fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + elif max(shape[0], shape[1], shape[2]) == shape[1]: + fn_npu_[1, shape[1], 1](output, x, y, z, shape[0], 1, shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_or_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x | y + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_or_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) + diff --git a/third_party/ascend/examples/generalization_cases/test_permute_1d_2d.py b/third_party/ascend/examples/generalization_cases/test_permute_1d_2d.py new file mode 100644 index 000000000..0ead1a6cb --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_permute_1d_2d.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow +import math +import logging + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, xnumel:tl.constexpr): + idx = tl.arange(0, xnumel) + + X = tl.load(x_ptr + idx) + + ret = tl.permute(X, (0,)) + + tl.store(output_ptr + idx, ret) + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_permute_1d(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + + data_type = eval('torch.' 
+ dtype) + x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() + + triton_res = torch.randint(1, shape, dtype=data_type).npu() + torch_res = torch.permute(x, (0,)) + fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) + test_common.validate_cmp(dtype, triton_res, torch_res) + + +@triton.jit +def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, ynumel: tl.constexpr, znumel:tl.constexpr): + pid = tl.program_id(0) + yidx = tl.arange(0, YB) + pid * YB + zidx = tl.arange(0, ZB) + idx = yidx[:, None] * znumel + zidx[None, :] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret = tl.permute(X, (1, 0)) + + oidx = zidx[:, None] * ynumel + yidx[None, :] + + tl.store(output_ptr + oidx, ret) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_permute(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + + ynumel=shape[0]; YB = 1 + znumel=shape[1]; ZB = shape[1] + + data_type = eval('torch.' + dtype) + x = torch.randint(low=0, high=2, size=(shape[0], shape[1]), dtype=data_type).npu() + + triton_res = torch.randint(1, (shape[1], shape[0]), dtype=data_type).npu() + torch_res = torch.permute(x, (1, 0)) + fn_npu_021[shape[0], 1, 1](triton_res, x, YB, ZB, ynumel, znumel) + test_common.validate_cmp(dtype, triton_res, torch_res) + +if __name__ == "__main__": + for shape in [(37, 3)]: + for dtype in TestUtils.dtype_list: + test_permute(shape, dtype) diff --git a/third_party/ascend/examples/generalization_cases/test_permute_3d.py b/third_party/ascend/examples/generalization_cases/test_permute_3d.py new file mode 100644 index 000000000..963c55399 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_permute_3d.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
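+#
+# The 3d permute kernels in this file all follow the same pattern: a (YB, ZB, KB)
+# block is read with row-major input offsets
+#     idx  = y * ZB * KB + z * KB + k
+# and the permuted block is stored with the output's own row-major strides, e.g.
+# for tl.permute(X, (1, 0, 2)) the output shape is (ZB, YB, KB), so
+#     oidx = z * YB * KB + y * KB + k
+# As a small worked example (illustration only), with (YB, ZB, KB) = (2, 3, 4) the
+# element (y, z, k) = (1, 0, 2) is read from idx = 1*12 + 0*4 + 2 = 14 and written
+# to oidx = 0*8 + 1*4 + 2 = 6.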
+ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow +import math +import logging +@triton.jit +def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.permute(X, (1, 0, 2)) + + oidx = zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_210(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.permute(X, (2, 1, 0)) + + oidx = kidx[:, None, None] * ZB * YB + zidx[None, :, None] * YB + yidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.permute(X, (0, 2, 1)) + + oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', ["int8", 'int16', 'int32', 'float16', 'float32', 'bfloat16', 'int64']) +def test_permute_3d(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + + data_type = eval('torch.' + dtype) + x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() + + triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() + torch_res = torch.permute(x, (1, 0, 2)) + fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + test_common.validate_cmp(dtype, triton_res, torch_res) + + # not support yet: need bisheng support later + # triton_res = torch.empty((shape[2], shape[1], shape[0]), dtype=data_type).npu() + # torch_res = torch.permute(x, (2, 1, 0)) + # fn_npu_210[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + # test_common.validate_cmp(dtype, triton_res, torch_res) + + triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() + torch_res = torch.permute(x, (0, 2, 1)) + fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + test_common.validate_cmp(dtype, triton_res, torch_res) + diff --git a/third_party/ascend/examples/generalization_cases/test_permute_4d_5d.py b/third_party/ascend/examples/generalization_cases/test_permute_4d_5d.py new file mode 100644 index 000000000..d4a7a9809 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_permute_4d_5d.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
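+#
+# The 4d/5d permute kernels in this file rebuild contiguous (row-major) strides for
+# the permuted output from its shape, innermost dimension first, e.g. for 4d:
+#     s3 = 1; s2 = s3 * shape3; s1 = s2 * shape2; s0 = s1 * shape1
+# A rough sketch of the same idea in plain Python (illustration only; the helper
+# name `row_major_strides` and its argument are hypothetical, nothing below uses
+# them):
+#     def row_major_strides(permuted_shape):
+#         strides = [1]
+#         for dim in reversed(permuted_shape[1:]):
+#             strides.insert(0, strides[0] * dim)
+#         return strides
+#     row_major_strides([2, 3, 4, 5])  # -> [60, 20, 5, 1]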
+ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow +import math +import logging + + +@triton.jit +def triton_permute_4d( + output_ptr, x_ptr, PERM: tl.constexpr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] + tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] + tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None] + tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] + tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None] + tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] + tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None] + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) + x_val = tl.load(x_ptr + offsets, masks) + + if PERM == 0: # 1, 0, 2, 3 + ret = tl.permute(x_val, (1, 0, 2, 3)) + shape0 = SHAPE_1 + shape1 = SHAPE_0 + shape2 = SHAPE_2 + shape3 = SHAPE_3 + elif PERM == 1: # 0, 2, 1, 3 + ret = tl.permute(x_val, (0, 2, 1, 3)) + shape0 = SHAPE_0 + shape1 = SHAPE_2 + shape2 = SHAPE_1 + shape3 = SHAPE_3 + else: # 0, 1, 3, 2 + ret = tl.permute(x_val, (0, 1, 3, 2)) + shape0 = SHAPE_0 + shape1 = SHAPE_1 + shape2 = SHAPE_3 + shape3 = SHAPE_2 + + s3 = 1 + s2 = s3 * shape3 + s1 = s2 * shape2 + s0 = s1 * shape1 + + if PERM == 0: # 1, 0, 2, 3 + out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) + elif PERM == 1: # 0, 2, 1, 3 + out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) + else: # 0, 1, 3, 2 + out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@triton.jit +def triton_permute_5d( + output_ptr, x_ptr, PERM: tl.constexpr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] + tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] + + tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None, None] + tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None, None] + + tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None, None] + tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None, None] + + tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :, None] + tmp3_2 = tl.arange(0, BLOCK_3)[None, 
None, :, None, None] + + tmp3_4 = tl.arange(0, BLOCK_3)[None, None, None, None, :] + tmp4_3 = tl.arange(0, BLOCK_4)[None, None, None, :, None] + + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) + x_val = tl.load(x_ptr + offsets, masks) + + if PERM == 0: # 1, 0, 2, 3, 4 + ret = tl.permute(x_val, 1, 0, 2, 3, 4) + shape0 = SHAPE_1 + shape1 = SHAPE_0 + shape2 = SHAPE_2 + shape3 = SHAPE_3 + shape4 = SHAPE_4 + elif PERM == 1: # 0, 2, 1, 3, 4 + ret = tl.permute(x_val, 0, 2, 1, 3, 4) + shape0 = SHAPE_0 + shape1 = SHAPE_2 + shape2 = SHAPE_1 + shape3 = SHAPE_3 + shape4 = SHAPE_4 + elif PERM == 2: # 0, 1, 3, 2, 4 + ret = tl.permute(x_val, 0, 1, 3, 2, 4) + shape0 = SHAPE_0 + shape1 = SHAPE_1 + shape2 = SHAPE_3 + shape3 = SHAPE_2 + shape4 = SHAPE_4 + else: # 0, 1, 2, 4, 3 + ret = tl.permute(x_val, 0, 1, 2, 4, 3) + shape0 = SHAPE_0 + shape1 = SHAPE_1 + shape2 = SHAPE_2 + shape3 = SHAPE_4 + shape4 = SHAPE_3 + + s4 = 1 + s3 = s4 * shape4 + s2 = s3 * shape3 + s1 = s2 * shape2 + s0 = s1 * shape1 + + if PERM == 0: # 1, 0, 2, 3, 4 + out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + tmp4 * s4 + out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) & (tmp4 < shape4) + elif PERM == 1: # 0, 2, 1, 3, 4 + out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + tmp4 * s4 + out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) & (tmp4 < shape4) + elif PERM == 2: # 0, 1, 3, 2, 4 + out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + tmp4 * s4 + out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) & (tmp4 < shape4) + else: # 0, 1, 2, 4, 3 + out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp2 * s2 + tmp4_3 * s3 + tmp3_4 * s4 + out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp2 < shape2) & (tmp4_3 < shape3) & (tmp3_4 < shape4) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +@pytest.mark.parametrize('perm', [0, 1, 2, 3]) # 4d: support 3 mode; 5d: support 4 mode +def test_permute_4d_5d(shape, dtype, perm): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.randint(low=0, high=2, size=shape, dtype=eval('torch.' + dtype)).npu() + grid = (1,) + if len(shape) == 4: + blocks = list(x.size()) + strides = list(x.stride()) + if perm == 0: # 1, 0, 2, 3; exchange axis 0, 1 + output = torch.empty((shape[1], shape[0], shape[2], shape[3]), dtype=eval('torch.' + dtype)).npu() + ans_4d = torch.permute(x, (1, 0, 2, 3)) + triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_4d, output) + elif perm == 1: # 0, 2, 1, 3; exchange axis 1, 2 + output = torch.empty((shape[0], shape[2], shape[1], shape[3]), dtype=eval('torch.' + dtype)).npu() + ans_4d = torch.permute(x, (0, 2, 1, 3)) + triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_4d, output) + elif perm == 2: # 0, 1, 3, 2; exchange axis 2, 3 + output = torch.empty((shape[0], shape[1], shape[3], shape[2]), dtype=eval('torch.' 
+ dtype)).npu() + ans_4d = torch.permute(x, (0, 1, 3, 2)) + triton_permute_4d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_4d, output) + else: + pass + else: + blocks = list(x.size()) + strides = list(x.stride()) + + if perm == 0: # 1, 0, 2, 3, 4; exchange axis 0, 1 + output = torch.empty((shape[1], shape[0], shape[2], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (1, 0, 2, 3, 4)) + elif perm == 1: # 0, 2, 1, 3, 4; exchange axis 1, 2 + output = torch.empty((shape[0], shape[2], shape[1], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (0, 2, 1, 3, 4)) + elif perm == 2: # 0, 1, 3, 2, 4; exchange axis 2, 3 + output = torch.empty((shape[0], shape[1], shape[3], shape[2], shape[4]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (0, 1, 3, 2, 4)) + else: # 0, 1, 2, 4, 3; exchange axis 3, 4 + output = torch.empty((shape[0], shape[1], shape[2], shape[4], shape[3]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (0, 1, 2, 4, 3)) + triton_permute_5d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_5d, output) diff --git a/third_party/ascend/examples/generalization_cases/test_rand.py b/third_party/ascend/examples/generalization_cases/test_rand.py new file mode 100644 index 000000000..eee439f96 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_rand.py @@ -0,0 +1,302 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math +import numpy as np +import scipy + + +@triton.jit +def kernel_rand(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(block_size): + global_offset = block_offset + inner_idx + rand_vals = tl.rand(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 + + +@triton.jit +def triton_rand_4d_5d( + output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, + BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, + SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + # 1D program_id for flatten multi-d offset + pid = tl.program_id(0) + # base offset for dimension 0 + offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 + mask = tl.arange(0, BLOCK_0) < SHAPE_0 + # nested offset expansion + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + 
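+    # tl.rand(seed, offset, n_rounds) produces one uniform sample in [0, 1) per
+    # element, using the flattened multi-d offsets built above as per-element
+    # counters of the underlying Philox generator; seed=5 and n_rounds=10 are
+    # simply fixed test choices here.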
ret = tl.rand(5, offsets, 10) + tl.store(output_ptr + offsets, ret, mask=mask) + + +@triton.jit +def kernel_randn(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(block_size): + global_offset = block_offset + inner_idx + rand_vals = tl.randn(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 + + +@triton.jit +def triton_randn_4d_5d( + output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, + BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, + SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + # 1D program_id for flatten multi-d offset + pid = tl.program_id(0) + # base offset for dimension 0 + offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 + mask = tl.arange(0, BLOCK_0) < SHAPE_0 + # nested offset expansion + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + ret = tl.randn(5, offsets, 10) + tl.store(output_ptr + offsets, ret, mask=mask) + + +@triton.jit +def kernel_randint(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(block_size): + global_offset = block_offset + inner_idx + rand_vals = tl.randint(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 + + +@triton.jit +def triton_randint_4d_5d( + output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, + BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, + SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + # 1D program_id for flatten multi-d offset + pid = tl.program_id(0) + # base offset for dimension 0 + offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 + mask = tl.arange(0, BLOCK_0) < SHAPE_0 + # nested offset expansion + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = 
offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + ret = tl.randint(5, offsets, 10) + tl.store(output_ptr + offsets, ret, mask=mask) + + +@triton.jit +def kernel_randint4x(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + indices = tl.arange(0, 4) + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(0, block_size + 4, step=4): + global_offset = block_offset + inner_idx + rand_vals = tl.randint4x(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + mask = (global_offset + indices) < (block_offset + block_size) + tl.store(x_ptr + global_offset + indices, rand_vals, mask) # 存储随机数 + + +@triton.jit +def triton_randint4x_4d_5d( + output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, + BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, + SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, + STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + # 1D program_id for flatten multi-d offset + pid = tl.program_id(0) + # base offset for dimension 0 + offsets = pid + tl.arange(0, BLOCK_0) * STRIDE_0 + mask = tl.arange(0, BLOCK_0) < SHAPE_0 + # nested offset expansion + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + mask = mask[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + mask = mask[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + mask = mask[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + mask = mask[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + ret = tl.randint4x(5, offsets, 10) + tl.store(output_ptr + offsets, ret, mask=mask) + + +# With alpha=0.01, z=-3.0902, N=100, we have (1-0.01)+(-3.0902)*sqrt(0.01*(1-0.01)/100)=0.9593, +# so there must be 96 cases for each shape to have pvalue larger than 0.01. +# There is higher possibility to fail with small shapes, so we will use large shape. 
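+# The 0.9593 bound quoted above can be reproduced with a few lines of Python
+# (illustration only, kept as a comment; the names alpha/z/n are not used by the
+# tests):
+#     alpha, z, n = 0.01, -3.0902, 100
+#     bound = (1 - alpha) + z * math.sqrt(alpha * (1 - alpha) / n)   # ~0.9593
+#     min_passes = math.ceil(bound * n)                              # 96
+# which is why each *_case test below asserts `correctness > 95`.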
+@pytest.mark.parametrize('shape', [ + (256, 256), + (512, 512), + (1024, 1024), +]) +def test_rand_case(shape): + y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() + + numel = y_calf.numel() + ncore = 1 if numel < 32 else 32 + xblock = math.ceil(numel / ncore) + + correctness = 0 + for _ in range(100): + ref = np.random.random_sample(shape).flatten() + kernel_rand[ncore, 1, 1](y_calf, 10, numel, xblock) + + pvalue = scipy.stats.kstest(ref, y_calf.cpu().numpy().flatten()).pvalue + if pvalue > 0.01: + correctness += 1 + + assert correctness > 95 + + +@pytest.mark.parametrize('shape', [ + (256, 256), + (512, 512), + (1024, 1024), +]) +def test_randn_case(shape): + y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() + + numel = y_calf.numel() + ncore = 1 if numel < 32 else 32 + xblock = math.ceil(numel / ncore) + + correctness = 0 + for _ in range(100): + ref = np.random.standard_normal(shape).flatten() + kernel_randn[ncore, 1, 1](y_calf, 10, numel, xblock) + + pvalue = scipy.stats.kstest(ref, y_calf.cpu().numpy().flatten()).pvalue + if pvalue > 0.01: + correctness += 1 + + assert correctness > 95 + + +@pytest.mark.parametrize('shape', [ + (256, 256), + (512, 512), + (1024, 1024), +]) +def test_randint_case(shape): + y_cali = torch.zeros(shape, dtype=eval('torch.int32')).npu() + + numel = y_cali.numel() + ncore = 1 if numel < 32 else 32 + xblock = math.ceil(numel / ncore) + + correctness = 0 + ii32 = np.iinfo(np.int32) + for _ in range(100): + ref = np.random.randint(low=ii32.min, high=ii32.max, size=shape).flatten() + kernel_randint[ncore, 1, 1](y_cali, 10, numel, xblock) + + pvalue = scipy.stats.kstest(ref, y_cali.cpu().numpy().flatten()).pvalue + if pvalue > 0.01: + correctness += 1 + + assert correctness > 95 + + +@pytest.mark.parametrize('shape', [ + (256, 256), + (512, 512), + (1024, 1024), +]) +def test_randint4x_case(shape): + y_cali = torch.zeros(shape, dtype=eval('torch.int32')).npu() + + numel = y_cali.numel() + ncore = 1 if numel < 32 else 32 + xblock = math.ceil(numel / ncore) + + correctness = 0 + ii32 = np.iinfo(np.int32) + for _ in range(100): + ref = np.random.randint(low=ii32.min, high=ii32.max, size=shape).flatten() + kernel_randint4x[ncore, 1, 1](y_cali, 10, numel, xblock) + + pvalue = scipy.stats.kstest(ref, y_cali.cpu().numpy().flatten()).pvalue + if pvalue > 0.01: + correctness += 1 + + assert correctness > 95 + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +def test_rand_4d_5d(shape): + x = torch.zeros(shape, dtype=eval('torch.float32')).npu() + y = torch.zeros(shape, dtype=eval('torch.int32')).npu() + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_rand_4d_5d[grid](x, *blocks, *blocks, *strides) + triton_randn_4d_5d[grid](x, *blocks, *blocks, *strides) + triton_randint_4d_5d[grid](y, *blocks, *blocks, *strides) + triton_randint4x_4d_5d[grid](y, *blocks, *blocks, *strides) diff --git a/third_party/ascend/examples/generalization_cases/test_range.py b/third_party/ascend/examples/generalization_cases/test_range.py new file mode 100644 index 000000000..1fb52b243 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_range.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
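+#
+# Both kernels below accumulate inside `for _ in tl.range(2, 5, 2)` (or
+# tl.static_range), which iterates over the two values 2 and 4, so the expected
+# result is x + y + 2*x. That is why the tests build the reference as
+# `ans = x + y + x + x`, computing it in float32 first for bfloat16, presumably to
+# avoid extra bfloat16 rounding in the reference.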
+import pytest +import triton +import torch +import test_common +import logging + +import triton.language as tl +from test_common import TestUtils + + +@triton.jit +def triton_range(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X + Y + for _ in tl.range(2, 5, 2): + ret = ret + X + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_static_range(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X + Y + for _ in tl.static_range(2, 5, 2): + ret = ret + X + + tl.store(output_ptr + idx, ret) + + +test_shape = [(1,), (2,), (1, 1), (3, 4), (1, 1, 1), (2, 4, 8)] + + +@pytest.mark.parametrize('shape', test_shape) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_range(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + if dtype == 'bfloat16': + ans = (x.to(torch.float32) + y.to(torch.float32) + x.to(torch.float32) + x.to(torch.float32)).to(torch.bfloat16) + else: + ans = x + y + x + x + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if dtype == 'int8': + if x.numel() * x.element_size() >= 512: + grid = (1, 1, ZB) + ZB = 1 + else: + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + triton_range[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', test_shape) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +def test_static_range(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + if dtype == 'bfloat16': + ans = (x.to(torch.float32) + y.to(torch.float32) + x.to(torch.float32) + x.to(torch.float32)).to(torch.bfloat16) + else: + ans = x + y + x + x + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if dtype == 'int8': + if x.numel() * x.element_size() >= 512: + grid = (1, 1, ZB) + ZB = 1 + else: + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + triton_static_range[grid](output, x, y, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_reduce.py b/third_party/ascend/examples/generalization_cases/test_reduce.py new file mode 100644 index 000000000..500644e1e --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_reduce.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +import math +import random +import pytest +import torch +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, get_dtype_size + + +def torch_reduce(x1, dim): + if x1.dtype == torch.float16 or x1.dtype == torch.float32: + res = torch.sum(x1.to(torch.float32), dim=dim).to(x1.dtype) + else: + res = torch.sum(x1, dim=dim).to(x1.dtype) + return res + + +@triton.jit +def _reduce_combine(a, b): + return a + b + + +@triton.jit +def tt_reduce_1d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + idx = tl.arange(0, XB) + x = tl.load(in_ptr + idx) + ret = tl.reduce(x, dim, _reduce_combine) + tl.store(out_ptr + tl.arange(0, 1), ret) + + +@triton.jit +def tt_reduce_2d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + idx = xidx[:, None] * ynumel + yidx[None, :] + + x = tl.load(in_ptr + idx) + ret = tl.reduce(x, dim, _reduce_combine) + + if dim == 0: + oidx = yidx + else: + oidx = xidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_reduce_1d_dim_none(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + idx = tl.arange(0, XB) + x = tl.load(in_ptr + idx) + ret = tl.reduce(x, dim, _reduce_combine) + tl.store(out_ptr + tl.arange(0, 1), ret) + + +@triton.jit +def tt_reduce_2d_dim_none(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + idx = xidx[:, None] * ynumel + yidx[None, :] + + x = tl.load(in_ptr + idx) + ret = tl.reduce(x, dim, _reduce_combine) + + tl.store(out_ptr + tl.arange(0, 1), ret) + + +@triton.jit +def tt_reduce_3d_dim_none(in_ptr, out_ptr, + xnumel: 
tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + ret = tl.reduce(x, dim, _reduce_combine) + + tl.store(out_ptr, ret) + + +@triton.jit +def tt_reduce_3d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + ret = tl.reduce(x, dim, _reduce_combine) + + if dim == 0: + oidx = yidx[:, None] * znumel + zidx[None, :] + elif dim == 1: + oidx = xidx[:, None] * znumel + zidx[None, :] + else: + oidx = xidx[:, None] * ynumel + yidx[None, :] + + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_reduce_3d_0_1(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.reduce(x, 0, _reduce_combine) + ret = tl.reduce(tmp, 0, _reduce_combine) + oidx = zidx + + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_reduce_3d_0_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.reduce(x, 0, _reduce_combine) + ret = tl.reduce(tmp, 1, _reduce_combine) + oidx = yidx + + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_reduce_3d_1_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.reduce(x, 1, _reduce_combine) + ret = tl.reduce(tmp, 1, _reduce_combine) + oidx = xidx + + tl.store(out_ptr + oidx, ret) + + +def is_legal_combine(shape, dims): + return dims is None or (len(shape) == 3) or \ + (len(dims) == 1 and dims[0] < len(shape)) + + +dims_map = { + (0, 1): tt_reduce_3d_0_1, + (1, 2): tt_reduce_3d_1_2, + (0, 2): tt_reduce_3d_0_2 +} + +shape_map = { + 1: {"append_shape": (1, 1), "func": tt_reduce_1d}, + 2: {"append_shape": (1,), "func": tt_reduce_2d}, + 3: {"append_shape": (), "func": tt_reduce_3d} +} + + +def reduce_check_ub_mem_overflow(dtype, shape): + dtype_size = get_dtype_size(dtype) + if (dtype == "int8" or dtype == "bool") and dtype_size * math.prod(shape) >= (TestUtils.ub_size / 
20):
+        pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow, skipping.")
+    elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 6):
+        pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow, skipping.")
+
+
+@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5))
+@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
+@pytest.mark.parametrize('dims', [None, (0,), (1,), (2,), (0, 1), (1, 2), (0, 2)])
+def test_reduce(dtype, shape, dims):
+    if not is_legal_combine(shape, dims):
+        return
+
+    torch.manual_seed(0)
+    x = test_common.generate_tensor(shape, dtype).npu()
+    grid = (1, 1, 1)
+
+    y_ref = torch_reduce(x, dims)
+    y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu")
+
+    if dims is None:
+        reduce_check_ub_mem_overflow(dtype, shape)
+        append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"]
+        xnumel, ynumel, znumel = shape + append_shape
+        XB, YB, ZB = xnumel, ynumel, znumel
+        if len(shape) == 1:
+            tt_reduce_1d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims)
+        if len(shape) == 2:
+            tt_reduce_2d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims)
+        if len(shape) == 3:
+            tt_reduce_3d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims)
+
+        test_common.validate_cmp(dtype, y_cal, y_ref)
+
+    elif len(dims) == 1: # 1d reduce, 1-3d shape
+        append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"]
+        xnumel, ynumel, znumel = shape + append_shape
+        XB, YB, ZB = xnumel, ynumel, znumel
+        if (len(shape) == 2) and (x.numel() * x.element_size() > 8192):
+            if dims[0] == 0:
+                grid = (1, ynumel, 1)
+                YB = 1
+            else:
+                grid = (xnumel, 1, 1)
+                XB = 1
+        tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0])
+        test_common.validate_cmp(dtype, y_cal, y_ref)
+    else: # 3d shape, 2d reduce
+        tt_kernel = dims_map[dims]
+        xnumel, ynumel, znumel = shape
+        XB, YB, ZB = xnumel, ynumel, znumel
+
+        tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0])
+        test_common.validate_cmp(dtype, y_cal, y_ref)
+
+
+@triton.jit
+def triton_reduce_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr,
+                          NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr):
+    offsets = tl.arange(0, XB) * (YB * ZB * MB * NB)
+    if DIMS > 1:
+        offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB)
+    if DIMS > 2:
+        offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB)
+    if DIMS > 3:
+        offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB
+    if DIMS > 4:
+        offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :]
+
+    x = tl.load(in_ptr + offsets)
+
+    if DIM is not None:
+        ret = tl.reshape(tl.reduce(x, DIM, _reduce_combine), REDUCE_NUMEL)
+        o_offsets = tl.arange(0, REDUCE_NUMEL)
+        tl.store(out_ptr + o_offsets, ret)
+    else:
+        ret = tl.reduce(x, DIM, _reduce_combine)
+        tl.store(out_ptr, ret)
+
+
+@pytest.mark.shape_4d_5d
+@pytest.mark.parametrize('shape', [
+    (4, 2, 8, 4),
+    (4, 3, 8, 1),
+])
+@pytest.mark.parametrize('dtype', TestUtils.full_dtype)
+@pytest.mark.parametrize('dims', [None, (0,), (1,), (2,), (3,)])
+def test_reduce_4d(dtype, shape, dims):
+    torch.manual_seed(0)
+
+    x = test_common.generate_tensor(shape, dtype).npu()
+    dim = dims[0] if dims is not None else None
+
+    y_ref = torch_reduce(x, dim)
+    y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None + grid = (1,) + triton_reduce_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 2, 8, 4), + (3, 4, 2, 8, 1), +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dims', [None, (0,), (1,), (2,), (3,), (4,)]) +def test_reduce_5d(dtype, shape, dims): + torch.manual_seed(0) + + x = test_common.generate_tensor(shape, dtype).npu() + dim = dims[0] if dims is not None else None + + y_ref = torch_reduce(x, dim) + y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None + grid = (1,) + triton_reduce_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_relu.py b/third_party/ascend/examples/generalization_cases/test_relu.py new file mode 100644 index 000000000..d5b9a9071 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_relu.py @@ -0,0 +1,45 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +import triton.language.extra.ascend.libdevice as libdevice +from test_common import TestUtils +import math + +def torch_relu(x0, x1): + res = x0 + torch.relu(x1) + return res + + +@triton.jit +def triton_relu(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = x_index < xnumel + tmp0 = tl.load(in_ptr0 + x_index, xmask) + tmp1 = tl.load(in_ptr1 + x_index, xmask) + tmp2 = tmp0 + libdevice.relu(tmp1) + tl.store(out_ptr0 + x_index, tmp2, xmask) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['float32', 'float16']) +def test_relu(dtype, shape): + # generate input data + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + + numel = x0.numel() + ncore = 1 if numel <= 32 else 32 + xblock = math.ceil(numel / ncore) + xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) + + # torch reference result + torch_res = torch_relu(x0, x1) + # triton result + triton_res = test_common.generate_tensor(shape, dtype).npu() + triton_relu[ncore, 1, 1](x0, x1, triton_res, x0.numel(), xblock, xblock_sub) + # compare results + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/generalization_cases/test_rshift_op.py b/third_party/ascend/examples/generalization_cases/test_rshift_op.py new file mode 100644 index 000000000..59442c1a5 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_rshift_op.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
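# Illustrative sketch (assumed name `golden_rshift`, not used by the tests
# below): every kernel in this file computes `x >> 2` on signed integer
# tensors, where `>>` is an arithmetic shift, i.e. the sign bit is replicated,
# so -8 >> 2 == -2.
def golden_rshift(x, shift=2):
    # works for torch integer tensors and plain Python ints alike
    return x >> shift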
+import logging +import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common +from test_common import TestUtils + + +@triton.jit +def triton_rshift_1d(in_ptr0, out_ptr0, L : tl.constexpr): + lblk_idx = tl.arange(0,L) + idx = lblk_idx[:] + x0=tl.load(in_ptr0+idx) + ret = x0 >> 2 + odx = lblk_idx[:] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_rshift_2d(in_ptr0, out_ptr0, M : tl.constexpr, N : tl.constexpr): + moffs = tl.program_id(0) * M + mblk_idx = tl.arange(0,M) + moffs + nblk_idx = tl.arange(0,N) + idx = mblk_idx[:,None]*N+nblk_idx[None,:] + x0=tl.load(in_ptr0+idx) + ret = x0 >> 2 + odx = mblk_idx[:,None]*N+nblk_idx[None,:] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_rshift_3d(in_ptr0, out_ptr0, L : tl.constexpr, M : tl.constexpr, N : tl.constexpr): + loffs = tl.program_id(0) * L + lblk_idx = tl.arange(0,L) + loffs + mblk_idx = tl.arange(0,M) + nblk_idx = tl.arange(0,N) + idx = lblk_idx[:,None,None]*N*M+mblk_idx[None,:,None]*N+nblk_idx[None,None,:] + x0=tl.load(in_ptr0+idx) + ret = x0 >> 2 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0+odx, ret) + + +@triton.jit +def triton_rshift_4d_5d( + x_ptr, output_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = x_val >> 2 + tl.store(output_ptr + offsets, ret, mask=masks) + + +dtype_mapping = { + 'int8': (torch.int8), + 'int16': (torch.int16), + 'int32': (torch.int32), + 'uint32': (torch.uint32), + 'int64': (torch.int64), + 'float16': (torch.float16), + 'float32': (torch.float32), + 'bfloat16': (torch.bfloat16), + 'bool': (torch.bool), +} + +typelist = ['int8','int16','int32','int64',] + + +# @pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('shape', TestUtils.test_shape1_2_3d) +@pytest.mark.parametrize('sigtype',typelist) +def test_lshift(sigtype, shape): + dtype = dtype_mapping[sigtype] + x0 = test_common.generate_tensor(shape = shape, dtype = sigtype).npu() + # ncore, xblock, xblock_sub = 2, 32768, 1024 + y_ref = x0 >> 2 + output = torch.zeros(shape, dtype=dtype).npu() + if len(shape) == 3: + shape0 = shape[0] + shape1 = shape[1] + shape2 = shape[2] + if 
x0.numel() * x0.element_size() >= 1024: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_rshift_3d[grid](x0, output, shape0, shape1, shape2) + if len(shape) == 2: + shape0 = shape[0] + shape1 = shape[1] + if x0.numel() * x0.element_size() >= 1024: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + triton_rshift_2d[grid](x0, output, shape0, shape1) + if len(shape) == 1: + triton_rshift_1d[1, 1, 1](x0, output, shape[0]) + test_common.validate_cmp(sigtype, output, y_ref) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) +def test_rshift_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x >> 2 + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_rshift_4d_5d[grid](x, output, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + triton_rshift_1d[1, 1, 1](x, output, N) + diff --git a/third_party/ascend/examples/generalization_cases/test_scalar_tensor.py b/third_party/ascend/examples/generalization_cases/test_scalar_tensor.py new file mode 100644 index 000000000..14a6c1ba1 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_scalar_tensor.py @@ -0,0 +1,60 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +def torch_(x0, x1, op_type): + if op_type == 'mul': + return torch.tensor(x0 * x1) + elif op_type == 'lshift': + return torch.tensor(x0 << x1) + elif op_type == 'eq': + return torch.tensor(x0 == x1) + else : + raise TypeError('Invalid op_type') + + +@triton.jit +def scalar_mul(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): + scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.float32, [])) + scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.float32, [])) + ret = scalar0 * scalar1 + tl.store(out_ptr0, ret) + +@triton.jit +def scalar_lshift(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): + scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.int32, [])) + scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.int32, [])) + ret = scalar0 << scalar1 + tl.store(out_ptr0, ret) + +@triton.jit +def scalar_eq(out_ptr0, val0: tl.constexpr, val1: tl.constexpr): + scalar0 = tl.core.tensor(val0, tl.core.block_type(tl.int16, [])) + scalar1 = tl.core.tensor(val1, tl.core.block_type(tl.int16, [])) + ret = scalar0 == scalar1 + tl.store(out_ptr0, ret) + +@pytest.mark.parametrize('param_list', + [ + ['float32', 'mul', (1,), 3.14, 6.66], + ['int32', 'lshift', (1,), 6, 7], + ['bool', 'eq', (1,), 5, 5], + ] + ) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "0d block_type is forbidden") +def test_case(param_list): + dtype, op_type, shape, lval, rval = param_list + ans = torch_(lval, rval, op_type) + ret = torch.zeros(shape, dtype = 
eval('torch.' + dtype)).npu() + + if op_type == 'mul': + scalar_mul[1, 1, 1](ret, lval, rval) + elif op_type == 'lshift': + scalar_lshift[1, 1, 1](ret, lval, rval) + elif op_type == 'eq': + scalar_eq[1, 1, 1](ret, lval, rval) + + test_common.validate_cmp(dtype, ans, ret) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_sort.py b/third_party/ascend/examples/generalization_cases/test_sort.py new file mode 100644 index 000000000..5c1a0add3 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_sort.py @@ -0,0 +1,173 @@ +import triton +import pytest +import torch +import triton.language as tl +import test_common +from test_common import TestUtils + + +# ---------------------- +# 1D sort kernel +# ---------------------- +@triton.jit +def sort_kernel_1d(X, Z, M: tl.constexpr, descending: tl.constexpr): + off = tl.arange(0, M) + x = tl.load(X + off) + x = tl.sort(x, descending=descending, dim=0) + tl.store(Z + off, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", TestUtils.test_shape1d) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16"]) +def test_sort_1d(shape, descending, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending, dim=0)[0] + + triton_res = torch.zeros_like(x) + M = x.shape[0] + sort_kernel_1d[(1, )](x, triton_res, M, descending) + assert torch.equal(torch_res, triton_res) + + +# ---------------------- +# 2D sort kernel (split by rows, not cutting M axis) +# ---------------------- +@triton.jit +def sort_kernel_2d(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + pid = tl.program_id(0) + offx = tl.arange(0, M) + offy = pid * M + off2d = offx + offy + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending, dim=0) + tl.store(Z + off2d, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", TestUtils.test_shape2d) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16"]) +def test_sort_2d(shape, descending, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending, dim=1)[0] + + triton_res = torch.zeros_like(x) + N, M = x.shape + # one block per row + sort_kernel_2d[(N, )](x, triton_res, N, M, descending) + assert torch.equal(torch_res, triton_res), (torch_res, triton_res) + + +# ---------------------- +# 3D sort kernel (split by D0, D1, not cutting D2) +# ---------------------- +@triton.jit +def sort_kernel_3d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, descending: tl.constexpr): + pid = tl.program_id(0) + row_id = pid % D1 + batch_id = pid // D1 + + off2 = tl.arange(0, D2) + off1 = row_id * D2 + off0 = batch_id * D1 * D2 + off = off2 + off1 + off0 + + x = tl.load(X + off) + x = tl.sort(x, descending=descending, dim=0) # sort one full row + tl.store(Z + off, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", TestUtils.test_shape3d) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16"]) +def test_sort_3d(shape, descending, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending, dim=2)[0] + + triton_res = torch.zeros_like(x) + D0, D1, D2 = x.shape + # one block per (D0, D1) pair + sort_kernel_3d[(D0 * D1, )](x, triton_res, D0, D1, D2,
descending) + assert torch.equal(torch_res, triton_res), (torch_res, triton_res) + + +# ---------------------- +# 4D sort kernel +# ---------------------- +@triton.jit +def sort_kernel_4d(X, Z, + D0: tl.constexpr, D1: tl.constexpr, + D2: tl.constexpr, D3: tl.constexpr, + descending: tl.constexpr): + pid = tl.program_id(0) + row_id = pid % D2 + col_id = (pid // D2) % D1 + batch_id = pid // (D1 * D2) + + off3 = tl.arange(0, D3) + off2 = row_id * D3 + off1 = col_id * D2 * D3 + off0 = batch_id * D1 * D2 * D3 + off = off3 + off2 + off1 + off0 + + x = tl.load(X + off) + x = tl.sort(x, descending=descending, dim=0) + tl.store(Z + off, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", TestUtils.test_shape4d) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16"]) +def test_sort_4d(shape, descending, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending, dim=3)[0] + + triton_res = torch.zeros_like(x) + D0, D1, D2, D3 = x.shape + sort_kernel_4d[(D0 * D1 * D2, )](x, triton_res, D0, D1, D2, D3, descending) + assert torch.equal(torch_res, triton_res) + + +# ---------------------- +# 5D sort kernel +# ---------------------- +@triton.jit +def sort_kernel_5d(X, Z, + D0: tl.constexpr, D1: tl.constexpr, + D2: tl.constexpr, D3: tl.constexpr, + D4: tl.constexpr, + descending: tl.constexpr): + pid = tl.program_id(0) + row_id = pid % D3 + col_id = (pid // D3) % D2 + depth_id = (pid // (D2 * D3)) % D1 + batch_id = pid // (D1 * D2 * D3) + + off4 = tl.arange(0, D4) + off3 = row_id * D4 + off2 = col_id * D3 * D4 + off1 = depth_id * D2 * D3 * D4 + off0 = batch_id * D1 * D2 * D3 * D4 + off = off4 + off3 + off2 + off1 + off0 + + x = tl.load(X + off) + x = tl.sort(x, descending=descending, dim=0) + tl.store(Z + off, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", TestUtils.test_shape5d) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ["int8", "int16", "float16", "float32", "bfloat16"]) +def test_sort_5d(shape, descending, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending, dim=4)[0] + + triton_res = torch.zeros_like(x) + D0, D1, D2, D3, D4 = x.shape + sort_kernel_5d[(D0 * D1 * D2 * D3, )](x, triton_res, D0, D1, D2, D3, D4, descending) + assert torch.equal(torch_res, triton_res) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_sqrt.py b/third_party/ascend/examples/generalization_cases/test_sqrt.py new file mode 100644 index 000000000..2d5c8f753 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_sqrt.py @@ -0,0 +1,161 @@ +import logging + +import triton +import triton.language as tl +import torch +import numpy as np +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_sqrt(x0): + res = torch.sqrt(x0) + return res + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = 
tl.load(y_ptr + idx) + + ret = tl.sqrt(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_sqrt_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.sqrt(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_sqrt(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_types = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_sqrt_invalid_dtype_case(dtype): + x = test_common.generate_tensor((1,), dtype).npu() + y = test_common.generate_tensor((1,), dtype).npu() + z = test_common.generate_tensor((1,), dtype).npu() + + output = torch.randint(1, (1,), dtype=eval('torch.' 
+ dtype)).npu() + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_sqrt_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch.sqrt(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_sqrt_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_sqrt_rn.py b/third_party/ascend/examples/generalization_cases/test_sqrt_rn.py new file mode 100644 index 000000000..15f7b8564 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_sqrt_rn.py @@ -0,0 +1,160 @@ +import logging + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + + +def torch_sqrt_rn(x0): + tmp = torch.sqrt(x0) + return tmp + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = tl.sqrt_rn(X) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_sqrt_rn_4d_5d( + output_ptr, x_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + ret = tl.sqrt_rn(x_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16']) +def test_case2(dtype, 
shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_sqrt_rn(x) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +invalid_types = [ + 'int8', + 'int16', + 'int32', + 'uint32', + 'int64', + 'bool', +] + + +@pytest.mark.parametrize("dtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype") +def test_sqrt_rn_invalid_dtype_case(dtype): + x = test_common.generate_tensor((1,), dtype).npu() + y = test_common.generate_tensor((1,), dtype).npu() + z = test_common.generate_tensor((1,), dtype).npu() + + output = torch.randint(1, (1,), dtype=eval('torch.' + dtype)).npu() + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) +def test_sqrt_rn_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = torch_sqrt_rn(x) + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_sqrt_rn_4d_5d[grid](output, x, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_static_print_and_assert.py b/third_party/ascend/examples/generalization_cases/test_static_print_and_assert.py new file mode 100644 index 000000000..bfce6c126 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_static_print_and_assert.py @@ -0,0 +1,141 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common +import functools +import os +import re + +shape = (8,) +XS = 8 + +XVALS_INT = [0, + -128, # torch.iinfo(torch.int8).min + 127, # torch.iinfo(torch.int8).max + -32768, # torch.iinfo(torch.int16).min + 32767, # torch.iinfo(torch.int16).max + -2147483648, # torch.iinfo(torch.int32).min + 2147483647, # torch.iinfo(torch.int32).max + 9223372036854775807] # torch.iinfo(torch.int64).max + +XVALS_FP = [0.0000000000e+00, # 0 + 1.1921000009e-07, # torch.finfo(torch.float32).eps + 9.7655999707e-04, # torch.finfo(torch.float16).eps + 7.8125000000e-03, # torch.finfo(torch.bfloat16).eps + 3.4027999388e+38, # torch.finfo(torch.float32).max + 6.5504000000e+04, # torch.finfo(torch.float16).max + 3.3894999515e+38, # torch.finfo(torch.bfloat16).max + 1.0000000000e+00] # 1 + + +def torch_func(x0, x1): + res = x0 + x1 + return res + + +@triton.jit +def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr, print_data_ptr: tl.constexpr, + assert_data_ptr: tl.constexpr): + idx = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.static_print(print_data_ptr) + tl.static_assert(assert_data_ptr == assert_data_ptr, "assert_data should equal assert_data") + tl.store(out_ptr0 + idx, tmp2) + + +def triton_func(x0, x1, XS, print_data_ptr, assert_data_ptr): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, XS, print_data_ptr, assert_data_ptr) + return out + + +@pytest.mark.parametrize('sigtype', ['int8']) +@test_common.capture_output("-128") +def test_static_print_int8(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, -128, XVALS_INT[0]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int16']) +@test_common.capture_output("-32768") +def test_static_print_int16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, -32768, XVALS_INT[2]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int32']) +@test_common.capture_output("-2147483648") +def test_static_print_int32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, -2147483648, XVALS_INT[4]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int64']) 
+@test_common.capture_output("9223372036854775807") +def test_static_print_int64(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 9223372036854775807, XVALS_INT[-1]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['float16']) +@test_common.capture_output("1.1921000009e-07") +def test_static_print_float16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 1.1921000009e-07, XVALS_FP[1]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['float32']) +@test_common.capture_output("0.0078125") +def test_static_print_float32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 7.8125000000e-03, XVALS_FP[0]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['bfloat16']) +@test_common.capture_output("0.00097655999707") +def test_static_print_bfloat16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 9.7655999707e-04, XVALS_FP[2]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int8']) +@test_common.capture_output("True") +def test_static_print_bool(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, True, True) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_sum.py b/third_party/ascend/examples/generalization_cases/test_sum.py new file mode 100644 index 000000000..fc5e34071 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_sum.py @@ -0,0 +1,328 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
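# Illustrative sketch (assumed name `pad_shape_to_3d`, not used by the tests
# below): the 1D/2D/3D sum kernels in this file all view their input as a
# logical (xnumel, ynumel, znumel) block, so shorter shapes are right-padded
# with 1s before launch, mirroring the `append_shape` entries of `shape_map`
# defined further down.
def pad_shape_to_3d(shape):
    # (8,) -> (8, 1, 1); (4, 8) -> (4, 8, 1); (2, 4, 8) -> (2, 4, 8)
    return tuple(shape) + (1,) * (3 - len(shape))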
+import math +import random +import pytest +import torch +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, get_dtype_size + + +def torch_sum(x1, dim): + if x1.dtype == torch.float16 or x1.dtype == torch.bfloat16: + res = torch.sum(x1.to(torch.float32), dim=dim, keepdim=False).to(x1.dtype) + else: + res = torch.sum(x1, dim=dim, keepdim=False).to(x1.dtype) + return res + + +@triton.jit +def tt_sum_1d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + idx = tl.arange(0, XB) + x = tl.load(in_ptr + idx) + ret = tl.sum(x, dim) + tl.store(out_ptr + tl.arange(0, 1), ret) + + +@triton.jit +def tt_sum_2d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + idx = xidx[:, None] * ynumel + yidx[None, :] + + x = tl.load(in_ptr + idx) + ret = tl.sum(x, dim) + + if dim == 0: + oidx = yidx + else: + oidx = xidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_sum_1d_dim_none(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + idx = tl.arange(0, XB) + x = tl.load(in_ptr + idx) + ret = tl.sum(x, dim) + tl.store(out_ptr + tl.arange(0, 1), ret) + + +@triton.jit +def tt_sum_2d_dim_none(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + idx = xidx[:, None] * ynumel + yidx[None, :] + + x = tl.load(in_ptr + idx) + ret = tl.sum(x, dim) + + tl.store(out_ptr + tl.arange(0, 1), ret) + + +@triton.jit +def tt_sum_3d_dim_none(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + ret = tl.sum(x, dim) + + tl.store(out_ptr, ret) + + +@triton.jit +def tt_sum_3d(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + ret = tl.sum(x, dim) + + if dim == 0: + oidx = yidx[:, None] * znumel + zidx[None, :] + elif dim == 1: + oidx = xidx[:, None] * znumel + zidx[None, :] + else: + oidx = xidx[:, None] * ynumel + yidx[None, :] + + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_sum_3d_0_1(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, 
ZB: tl.constexpr, dim: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.sum(x, 0) + ret = tl.sum(tmp, 0) + oidx = zidx + + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_sum_3d_0_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.sum(x, 0) + ret = tl.sum(tmp, 1) + oidx = yidx + + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def tt_sum_3d_1_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, dim: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + + x = tl.load(in_ptr + idx) + + tmp = tl.sum(x, 1) + ret = tl.sum(tmp, 1) + oidx = xidx + + tl.store(out_ptr + oidx, ret) + + +def is_legal_combine(shape, dims): + return dims is None or (len(shape) == 3) or \ + (len(dims) == 1 and dims[0] < len(shape)) + + +dims_map = { + (0, 1): tt_sum_3d_0_1, + (1, 2): tt_sum_3d_1_2, + (0, 2): tt_sum_3d_0_2 +} + +shape_map = { + 1: {"append_shape": (1, 1), "func": tt_sum_1d}, + 2: {"append_shape": (1,), "func": tt_sum_2d}, + 3: {"append_shape": (), "func": tt_sum_3d} +} + + +def reduce_check_ub_mem_overflow(dtype, shape): + dtype_size = get_dtype_size(dtype) + if (dtype == "int8" or dtype == "bool") and dtype_size * math.prod(shape) >= (TestUtils.ub_size / 20): + pytest.skip("dtype:{dtype} shape:{shape} mem overflow, skipping.") + elif dtype_size * math.prod(shape) >= (TestUtils.ub_size / 6): + pytest.skip("dtype:{dtype} shape:{shape} mem overflow, skipping.") + + +@pytest.mark.parametrize('shape', random.sample(TestUtils.full_shape, 5)) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dims', [None, (0,), (1,), (2,), (0, 1), (1, 2), (0, 2)]) +def test_sum(dtype, shape, dims): + if not is_legal_combine(shape, dims): + return + + torch.manual_seed(0) + x = test_common.generate_tensor(shape, dtype).npu() + grid = (1, 1, 1) + + y_ref = torch_sum(x, dims) + y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") + + if dims is None: + reduce_check_ub_mem_overflow(dtype, shape) + append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] + xnumel, ynumel, znumel = shape + append_shape + XB, YB, ZB = xnumel, ynumel, znumel + if len(shape) == 1: + tt_sum_1d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) + if len(shape) == 2: + tt_sum_2d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) + if len(shape) == 3: + tt_sum_3d_dim_none[1, 1, 1](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims) + + test_common.validate_cmp(dtype, y_cal, y_ref) + + elif len(dims) == 1: # 1d sum, 1-3d shape + append_shape, tt_kernel = shape_map[len(shape)]["append_shape"], shape_map[len(shape)]["func"] + xnumel, ynumel, znumel = shape + append_shape + XB, YB, ZB = xnumel, ynumel, znumel + if (len(shape) == 2) and (x.numel() * x.element_size() > 8192): + if dims[0] == 0: + grid = (1, ynumel, 1) + YB = 1 + else: + grid = (xnumel, 1, 1) + XB = 1 + tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) + test_common.validate_cmp(dtype, y_cal, y_ref) + else: # 3d shape, 2d sum + tt_kernel = dims_map[dims] + xnumel, ynumel, znumel = shape + XB, YB, ZB = xnumel, ynumel, znumel + tt_kernel[grid](x, y_cal, xnumel, ynumel, znumel, XB, YB, ZB, dims[0]) + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@triton.jit +def triton_sum_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, + NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if DIMS > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if DIMS > 2: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if DIMS > 3: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if DIMS > 4: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + x = tl.load(in_ptr + offsets) + + if DIM is not None: + ret = tl.reshape(tl.sum(x, DIM), REDUCE_NUMEL) + o_offsets = tl.arange(0, REDUCE_NUMEL) + tl.store(out_ptr + o_offsets, ret) + else: + ret = tl.sum(x, DIM) + tl.store(out_ptr, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (4, 2, 8, 4), + (4, 3, 8, 1), +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dims', [None, (0,), (1,), (2,), (3,)]) +def test_sum_4d(dtype, shape, dims): + torch.manual_seed(0) + + x = test_common.generate_tensor(shape, dtype).npu() + dim = dims[0] if dims is not None else None + + y_ref = torch_sum(x, dim) + y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' 
+ dtype), device="npu") + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None + grid = (1,) + triton_sum_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 2, 8, 4), + (3, 4, 2, 8, 1), +]) +@pytest.mark.parametrize('dtype', TestUtils.full_dtype) +@pytest.mark.parametrize('dims', [None, (0,), (1,), (2,), (3,), (4,)]) +def test_sum_5d(dtype, shape, dims): + torch.manual_seed(0) + + x = test_common.generate_tensor(shape, dtype).npu() + dim = dims[0] if dims is not None else None + + y_ref = torch_sum(x, dim) + y_cal = torch.empty(y_ref.shape, dtype=eval('torch.' + dtype), device="npu") + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None + grid = (1,) + triton_sum_multi_d[grid](x, y_cal, *triton_shape, len(shape), dim, reduce_numel) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_sum_dim0.py b/third_party/ascend/examples/generalization_cases/test_sum_dim0.py new file mode 100644 index 000000000..432bfed8d --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_sum_dim0.py @@ -0,0 +1,50 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, get_dtype_size +import math + +def torch_sum(x0): + res = torch.sum(x0, 0) + return res + +@triton.jit +def triton_sum(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RBLOCK_SUB : tl.constexpr): + xindex = tl.arange(0, XBLOCK) + xmask = xindex[:, None] < xnumel + for roffset_sub in range(0, RBLOCK, RBLOCK_SUB): + rindex = roffset_sub + tl.arange(0, RBLOCK_SUB) + x0 = xindex + r1 = rindex + rmask = rindex < rnumel + tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK*x0[:, None])), xmask & rmask) + tmp2 = tl.reshape(tmp0, [XBLOCK, RBLOCK_SUB]) + tmp4 = tl.sum(tmp2, 0) + tl.store(out_ptr1 + (rindex), tmp4, rmask) + +def should_skip_due_to_mem(dtype, shape): + dtype_size = get_dtype_size(dtype) + total_mem = dtype_size * math.prod(shape) + threshold = TestUtils.ub_size / 1.5 + + if total_mem >= threshold: + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['float32', 'int32']) +def test_case(dtype, shape): + should_skip_due_to_mem(dtype, shape) + x0 = test_common.generate_tensor(shape, dtype).npu() + + rblock = shape[1] + xblock = shape[0] + ncore = 1 #if numel <= 32 else 32 + rblock_sub = rblock #if xblock <= 16 else 16 + RBLOCK_tl = 256 if rblock > 1 else 1 + + y_ref = torch_sum(x0) + y_cal = torch.zeros(shape[1], dtype=eval('torch.' 
+ dtype)).npu() + triton_sum[ncore, 1, 1](x0, y_cal, xblock, rblock, xblock, rblock, rblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_sum_dim1.py b/third_party/ascend/examples/generalization_cases/test_sum_dim1.py new file mode 100644 index 000000000..89d4945fc --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_sum_dim1.py @@ -0,0 +1,43 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math + +def torch_sum(x0): + res = torch.sum(x0, 1) + return res + +@triton.jit +def triton_sum(in_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr, RBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + rindex = tl.arange(0, RBLOCK)[None, :] + rmask = rindex < rnumel + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB) + x0 = xindex + r1 = rindex + xmask = xindex[:, None] < xnumel + xmask_prime = xindex < xnumel + tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK*x0[:, None])), rmask & xmask) + tmp2 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) + tmp4 = tl.sum(tmp2, 1) + tl.store(out_ptr1 + (xindex), tmp4, xmask_prime) + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['float32', 'float16', 'int32']) +def test_case(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + + rblock = shape[1] + xblock = shape[0] + ncore = 1 #if numel <= 32 else 32 + xblock_sub = xblock if xblock <= 16 else 16 + RBLOCK_tl = 256 if rblock > 1 else 1 + + y_ref = torch_sum(x0) + y_cal = torch.zeros(shape[:-1], dtype=eval('torch.' + dtype)).npu() + triton_sum[ncore, 1, 1](x0, y_cal, xblock, rblock, xblock, xblock_sub, rblock) + test_common.validate_cmp(dtype, y_cal, y_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/generalization_cases/test_swizzle2d.py b/third_party/ascend/examples/generalization_cases/test_swizzle2d.py new file mode 100644 index 000000000..b1bb5717d --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_swizzle2d.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
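# Illustrative sketch (assumed name `swizzle2d_scalar`, not used by the test
# below): a scalar restatement of the grouped re-ordering that tl.swizzle2d and
# the tensorized torch reference below compute for a single (i, j) index.
def swizzle2d_scalar(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                 # row-major linear index
    group_id = ij // (size_g * size_j)  # which group of size_g consecutive rows
    off_i = group_id * size_g           # first row of that group
    rows = min(size_i - off_i, size_g)  # the last group may be shorter
    ij = ij % (size_g * size_j)
    # walk column-major inside the group
    return off_i + ij % rows, ij // rows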
+ +import random +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common +from test_common import TestUtils + + +def swizzle2d(size_i, size_j, size_g): + i = torch.arange(0, size_i)[:, None] + j = torch.arange(0, size_j)[None, :] + ij = i * size_j + j + size_gj = size_g * size_j + group_id = ij // size_gj + off_i = group_id * size_g + size_g = torch.min(size_i - off_i, torch.tensor(size_g).expand_as(off_i)) + ij = ij % size_gj + new_i = off_i + ij % size_g + new_j = ij // size_g + ret = new_i * size_i + new_j + return ret + + +@triton.jit +def fn_npu_(out0, out1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + i = tl.arange(0, XB)[:, None] + j = tl.arange(0, YB)[None, :] + ij = i * YB + j + xx, yy = tl.swizzle2d(i, j, size_i=XB, size_j=YB, size_g=ZB) + + ptr = tl.load(out0) + xx = tl.cast(xx, dtype=ptr.dtype) + yy = tl.cast(yy, dtype=ptr.dtype) + tl.store(out0 + ij, xx) + tl.store(out1 + ij, yy) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) +def test_swizzle2d(shape, dtype): + if (shape[0] > 255) or (shape[1] > 255): + return + size_g = random.randint(1, min(shape[0], shape[1])) + ans = swizzle2d(shape[0], shape[1], size_g).to(eval('torch.' + dtype)).npu() + + out0 = test_common.generate_tensor(shape, dtype).npu() + out1 = test_common.generate_tensor(shape, dtype).npu() + fn_npu_[1, 1, 1](out0, out1, shape[0], shape[1], size_g) + triton_ret = out0 * shape[0] + out1 + torch.testing.assert_close(triton_ret, ans) diff --git a/third_party/ascend/examples/generalization_cases/test_tan.py b/third_party/ascend/examples/generalization_cases/test_tan.py new file mode 100644 index 000000000..a7f27c3ae --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_tan.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils +import math +import triton.language.extra.ascend.libdevice as libdevice +def torch_pointwise(x0): + res = torch.tan(x0) + return res + +@triton.jit +def triton_tan(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp2 = libdevice.tan(tmp0) + tl.store(out_ptr0 + (x0), tmp2, None) + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['float32', 'float16']) +def test_case(dtype, shape): + x0 = test_common.generate_tensor(shape, dtype).npu() + + numel = x0.numel() + ncore = 1 if numel <= 32 else 32 + xblock = math.ceil(numel / ncore) + xblock_sub = numel if numel <= ncore else math.ceil(numel / ncore) + + y_ref = torch_pointwise(x0) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_tan[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_trans_1d_2d.py b/third_party/ascend/examples/generalization_cases/test_trans_1d_2d.py new file mode 100644 index 000000000..1b12c69d6 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_trans_1d_2d.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 
2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow +import math +import logging + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, xnumel:tl.constexpr): + idx = tl.arange(0, xnumel) + + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, 0) + + tl.store(output_ptr + idx, ret) + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_trans_1d(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + + data_type = eval('torch.' + dtype) + x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() + + triton_res = torch.randint(1, shape, dtype=data_type).npu() + torch_res = torch.permute(x, (0,)) + fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) + test_common.validate_cmp(dtype, triton_res, torch_res) + +@triton.jit +def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = yidx[:, None] * ZB + zidx[None, :] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, 1, 0) + + oidx = zidx[:, None] * YB + yidx[None, :] + + tl.store(output_ptr + oidx, ret) + +bisheng_notsupport_dtype = ['int64'] +tritonascend_notsupport_dtype = ['bool'] +# cases not caught by check_ub_mem_overflow: peak UB usage inside the kernel exceeds ub_size +mem_overflow_scene = [ +('bfloat16', (128, 256)), +('bfloat16', (256, 128)), +('int8', (741,256)), +('int8', (256,741)), +('int16', (256,256)), +('float16', (256,256)), +('bfloat16', (256,256)), +('int32', (128, 256)), +('int32', (256, 128)), +('float32', (128,256)), +('float32', (256,128)), +] +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_permute(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + if dtype in bisheng_notsupport_dtype or dtype in tritonascend_notsupport_dtype: + return + if (dtype, shape) in mem_overflow_scene: + return + if check_ub_mem_overflow(dtype, shape): + return + YB = shape[0] + ZB = shape[1] + data_type = eval('torch.' + dtype) + x = torch.randint(low=0, high=2, size=(YB, ZB), dtype=data_type).npu() + + triton_res = torch.randint(1, (ZB, YB), dtype=data_type).npu() + torch_res = torch.permute(x, (1, 0)) + fn_npu_021[1, 1, 1](triton_res, x, YB, ZB) + test_common.validate_cmp(dtype, triton_res, torch_res) + +if __name__ == "__main__": + for shape in [(37, 3)]: + for dtype in TestUtils.dtype_list: + test_permute(shape, dtype) diff --git a/third_party/ascend/examples/generalization_cases/test_trans_3d.py b/third_party/ascend/examples/generalization_cases/test_trans_3d.py new file mode 100644 index 000000000..38cafd9dd --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_trans_3d.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
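# Illustrative sketch (assumed name `row_major_strides`, not used by the tests
# below): each kernel in this file stores the transposed tile at offsets built
# from the row-major strides of the output shape, e.g. fn_npu_102 writes a
# (ZB, YB, KB) result with strides (YB * KB, KB, 1).
def row_major_strides(shape):
    # (2, 3, 4) -> (12, 4, 1)
    strides = [1] * len(shape)
    for k in range(len(shape) - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]
    return tuple(strides)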
+ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow +import math +import logging +@triton.jit +def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, 1, 0, 2) + + oidx = zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_210(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, 2, 1, 0) + + oidx = kidx[:, None, None] * ZB * YB + zidx[None, :, None] * YB + yidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, 0, 2, 1) + + oidx = yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] + + tl.store(output_ptr + oidx, ret) + +bisheng_notsupport_dtype = [] +tritonascend_notsupport_dtype = ['bool'] +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_permute_3d(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + if dtype in bisheng_notsupport_dtype or dtype in tritonascend_notsupport_dtype: + return + if check_ub_mem_overflow(dtype, shape): + return + + data_type = eval('torch.' + dtype) + x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() + + triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() + torch_res = torch.permute(x, (1, 0, 2)) + fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + test_common.validate_cmp(dtype, triton_res, torch_res) + + # not support yet: need bisheng support later + # triton_res = torch.empty((shape[2], shape[1], shape[0]), dtype=data_type).npu() + # torch_res = torch.permute(x, (2, 1, 0)) + # fn_npu_210[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + # test_common.validate_cmp(dtype, triton_res, torch_res) + + triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() + torch_res = torch.permute(x, (0, 2, 1)) + fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + test_common.validate_cmp(dtype, triton_res, torch_res) + +if __name__ == "__main__": + for shape in [(1, 22, 39)]: + for dtype in TestUtils.dtype_list: + test_permute_3d(shape, dtype) diff --git a/third_party/ascend/examples/generalization_cases/test_trans_4d_5d.py b/third_party/ascend/examples/generalization_cases/test_trans_4d_5d.py new file mode 100644 index 000000000..fbca7a229 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_trans_4d_5d.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
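# Illustrative sketch (assumes torch; `PERM_CODE_TO_DIMS_5D` and
# `golden_permute_5d` are assumed names, not used by the tests below): the
# kernels in this file encode the tested axis swaps as an integer PERM. The 5D
# mapping is spelled out here; the 4D kernels use codes 0-2 with the trailing
# axis dropped (code 3 is 5D-only).
import torch

PERM_CODE_TO_DIMS_5D = {
    0: (1, 0, 2, 3, 4),  # swap axes 0 and 1
    1: (0, 2, 1, 3, 4),  # swap axes 1 and 2
    2: (0, 1, 3, 2, 4),  # swap axes 2 and 3
    3: (0, 1, 2, 4, 3),  # swap axes 3 and 4
}

def golden_permute_5d(x, perm_code):
    # .contiguous() gives the row-major layout that triton_trans_5d writes out
    return torch.permute(x, PERM_CODE_TO_DIMS_5D[perm_code]).contiguous()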
+ +import triton +import triton.language as tl +import torch +import pytest +import test_common +from test_common import TestUtils, check_ub_mem_overflow +import math +import logging + + +@triton.jit +def triton_trans_4d( + output_ptr, x_ptr, PERM: tl.constexpr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None] + tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None] + tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None] + tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None] + tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None] + tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :] + tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, None] + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) + x_val = tl.load(x_ptr + offsets, masks) + + if PERM == 0: # 1, 0, 2, 3 + ret = tl.trans(x_val, (1, 0, 2, 3)) + shape0 = SHAPE_1 + shape1 = SHAPE_0 + shape2 = SHAPE_2 + shape3 = SHAPE_3 + elif PERM == 1: # 0, 2, 1, 3 + ret = tl.trans(x_val, (0, 2, 1, 3)) + shape0 = SHAPE_0 + shape1 = SHAPE_2 + shape2 = SHAPE_1 + shape3 = SHAPE_3 + else: # 0, 1, 3, 2 + ret = tl.trans(x_val, (0, 1, 3, 2)) + shape0 = SHAPE_0 + shape1 = SHAPE_1 + shape2 = SHAPE_3 + shape3 = SHAPE_2 + + s3 = 1 + s2 = s3 * shape3 + s1 = s2 * shape2 + s0 = s1 * shape1 + + if PERM == 0: # 1, 0, 2, 3 + out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) + elif PERM == 1: # 0, 2, 1, 3 + out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) + else: # 0, 1, 3, 2 + out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@triton.jit +def triton_trans_5d( + output_ptr, x_ptr, PERM: tl.constexpr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + pid = tl.program_id(0) + tmp0 = tl.arange(0, BLOCK_0)[:, None, None, None, None] + tmp1 = tl.arange(0, BLOCK_1)[None, :, None, None, None] + tmp2 = tl.arange(0, BLOCK_2)[None, None, :, None, None] + tmp3 = tl.arange(0, BLOCK_3)[None, None, None, :, None] + tmp4 = tl.arange(0, BLOCK_4)[None, None, None, None, :] + + tmp0_1 = tl.arange(0, BLOCK_0)[None, :, None, None, None] + tmp1_0 = tl.arange(0, BLOCK_1)[:, None, None, None, None] + + tmp1_2 = tl.arange(0, BLOCK_1)[None, None, :, None, None] + tmp2_1 = tl.arange(0, BLOCK_2)[None, :, None, None, None] + + tmp2_3 = tl.arange(0, BLOCK_2)[None, None, None, :, None] + tmp3_2 = tl.arange(0, BLOCK_3)[None, None, :, 
None, None] + + tmp3_4 = tl.arange(0, BLOCK_3)[None, None, None, None, :] + tmp4_3 = tl.arange(0, BLOCK_4)[None, None, None, :, None] + + offsets = pid + tmp0 * STRIDE_0 + tmp1 * STRIDE_1 + tmp2 * STRIDE_2 + tmp3 * STRIDE_3 + tmp4 * STRIDE_4 + masks = (tmp0 < SHAPE_0) & (tmp1 < SHAPE_1) & (tmp2 < SHAPE_2) & (tmp3 < SHAPE_3) & (tmp4 < SHAPE_4) + x_val = tl.load(x_ptr + offsets, masks) + + if PERM == 0: # 1, 0, 2, 3, 4 + ret = tl.trans(x_val, 1, 0, 2, 3, 4) + shape0 = SHAPE_1 + shape1 = SHAPE_0 + shape2 = SHAPE_2 + shape3 = SHAPE_3 + shape4 = SHAPE_4 + elif PERM == 1: # 0, 2, 1, 3, 4 + ret = tl.trans(x_val, 0, 2, 1, 3, 4) + shape0 = SHAPE_0 + shape1 = SHAPE_2 + shape2 = SHAPE_1 + shape3 = SHAPE_3 + shape4 = SHAPE_4 + elif PERM == 2: # 0, 1, 3, 2, 4 + ret = tl.trans(x_val, 0, 1, 3, 2, 4) + shape0 = SHAPE_0 + shape1 = SHAPE_1 + shape2 = SHAPE_3 + shape3 = SHAPE_2 + shape4 = SHAPE_4 + else: # 0, 1, 2, 4, 3 + ret = tl.trans(x_val, 0, 1, 2, 4, 3) + shape0 = SHAPE_0 + shape1 = SHAPE_1 + shape2 = SHAPE_2 + shape3 = SHAPE_4 + shape4 = SHAPE_3 + + s4 = 1 + s3 = s4 * shape4 + s2 = s3 * shape3 + s1 = s2 * shape2 + s0 = s1 * shape1 + + if PERM == 0: # 1, 0, 2, 3, 4 + out_offsets = pid + tmp1_0 * s0 + tmp0_1 * s1 + tmp2 * s2 + tmp3 * s3 + tmp4 * s4 + out_masks = (tmp1_0 < shape0) & (tmp0_1 < shape1) & (tmp2 < shape2) & (tmp3 < shape3) & (tmp4 < shape4) + elif PERM == 1: # 0, 2, 1, 3, 4 + out_offsets = pid + tmp0 * s0 + tmp2_1 * s1 + tmp1_2 * s2 + tmp3 * s3 + tmp4 * s4 + out_masks = (tmp0 < shape0) & (tmp1_2 < shape2) & (tmp2_1 < shape1) & (tmp3 < shape3) & (tmp4 < shape4) + elif PERM == 2: # 0, 1, 3, 2, 4 + out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp3_2 * s2 + tmp2_3 * s3 + tmp4 * s4 + out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp3_2 < shape2) & (tmp2_3 < shape3) & (tmp4 < shape4) + else: # 0, 1, 2, 4, 3 + out_offsets = pid + tmp0 * s0 + tmp1 * s1 + tmp2 * s2 + tmp4_3 * s3 + tmp3_4 * s4 + out_masks = (tmp0 < shape0) & (tmp1 < shape1) & (tmp2 < shape2) & (tmp4_3 < shape3) & (tmp3_4 < shape4) + tl.store(output_ptr + out_offsets, ret, mask=out_masks) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16']) +@pytest.mark.parametrize('perm', [0, 1, 2, 3]) # 4d: support 3 mode; 5d: support 4 mode +def test_trans_4d_5d(shape, dtype, perm): + logging.log(logging.DEBUG, f"shape = {shape}") + x = torch.randint(low=0, high=2, size=shape, dtype=eval('torch.' + dtype)).npu() + grid = (1,) + if len(shape) == 4: + blocks = list(x.size()) + strides = list(x.stride()) + if perm == 0: # 1, 0, 2, 3; exchange axis 0, 1 + output = torch.empty((shape[1], shape[0], shape[2], shape[3]), dtype=eval('torch.' + dtype)).npu() + ans_4d = torch.permute(x, (1, 0, 2, 3)) + triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_4d, output) + elif perm == 1: # 0, 2, 1, 3; exchange axis 1, 2 + output = torch.empty((shape[0], shape[2], shape[1], shape[3]), dtype=eval('torch.' + dtype)).npu() + ans_4d = torch.permute(x, (0, 2, 1, 3)) + triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_4d, output) + elif perm == 2: # 0, 1, 3, 2; exchange axis 2, 3 + output = torch.empty((shape[0], shape[1], shape[3], shape[2]), dtype=eval('torch.' 
+ dtype)).npu() + ans_4d = torch.permute(x, (0, 1, 3, 2)) + triton_trans_4d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_4d, output) + else: + pass + else: + blocks = list(x.size()) + strides = list(x.stride()) + + if perm == 0: # 1, 0, 2, 3, 4; exchange axis 0, 1 + output = torch.empty((shape[1], shape[0], shape[2], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (1, 0, 2, 3, 4)) + elif perm == 1: # 0, 2, 1, 3, 4; exchange axis 1, 2 + output = torch.empty((shape[0], shape[2], shape[1], shape[3], shape[4]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (0, 2, 1, 3, 4)) + elif perm == 2: # 0, 1, 3, 2, 4; exchange axis 2, 3 + output = torch.empty((shape[0], shape[1], shape[3], shape[2], shape[4]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (0, 1, 3, 2, 4)) + else: # 0, 1, 2, 4, 3; exchange axis 3, 4 + output = torch.empty((shape[0], shape[1], shape[2], shape[4], shape[3]), dtype=eval('torch.' + dtype)).npu() + ans_5d = torch.permute(x, (0, 1, 2, 4, 3)) + triton_trans_5d[grid](output, x, perm, *blocks, *blocks, *strides) + test_common.validate_cmp(dtype, ans_5d, output) diff --git a/third_party/ascend/examples/generalization_cases/test_umulhi.py b/third_party/ascend/examples/generalization_cases/test_umulhi.py new file mode 100644 index 000000000..37cd6530d --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_umulhi.py @@ -0,0 +1,128 @@ +import logging +import triton +import torch +import pytest +import test_common + +import numpy as np +import triton.language as tl +from test_common import TestUtils + +# inp the two 32 bit signed integers. +@triton.jit +def umulhi_kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + +@triton.jit +def triton_umulhi_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = tl.umulhi(x_val, y_val) + tl.store(output_ptr + offsets, ret, mask=masks) + + +# accuracy reference +def umulhi32(a, b): + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + 
product_64 = a_64 * b_64
+    # get the high part
+    result_high_32 = product_64 >> 32
+    return result_high_32.astype(np.int32)
+
+
+@pytest.mark.parametrize('dtype', ['int32'])
+@pytest.mark.parametrize('shape', TestUtils.full_shape)
+def test_case2(dtype, shape):
+    N = shape[0]
+    dtypes = eval('torch.' + dtype)
+    x = torch.randint(low=0, high=2000, size=shape, dtype=dtypes)
+    y = torch.randint(low=0, high=2000, size=shape, dtype=dtypes)
+    xx = x.npu()
+    yy = y.npu()
+    z_tri = torch.zeros(size=shape, dtype=dtypes).npu()
+    umulhi_kernel[(1,)](xx, yy, z_tri, N=N)
+
+    xxx = x.numpy()
+    yyy = y.numpy()
+    z_ref = umulhi32(xxx, yyy)
+    z_ref1 = torch.from_numpy(z_ref).npu()
+    assert torch.equal(z_tri, z_ref1)
+
+invalid_types = [
+    'int8',
+    'int16',
+    'int64',
+    'float16',
+    'float32',
+    'bfloat16',
+    'bool',
+]
+@pytest.mark.parametrize("dtype", invalid_types)
+@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype")
+def test_umulhi_invalid_dtype_case(dtype):
+    x0 = test_common.generate_tensor((1,), dtype).npu()
+    x1 = test_common.generate_tensor((1,), dtype).npu()
+
+    y_cal = torch.zeros((1,), dtype=eval('torch.' + dtype)).npu()
+    umulhi_kernel[(1,)](x0, x1, y_cal, 1)
+
+
+@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d)
+@pytest.mark.parametrize('dtype', ['int32'])
+def test_umulhi_4d_5d(shape, dtype):
+    logging.log(logging.DEBUG, f"shape = {shape}")
+    x = torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype))
+    y = torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype))
+    xx = x.npu()
+    yy = y.npu()
+
+    output = torch.zeros(size=shape, dtype=eval('torch.' + dtype)).npu()
+
+    logging.log(logging.DEBUG, f"output.dtype={output.dtype}")
+
+    xxx = x.numpy()
+    yyy = y.numpy()
+    z = umulhi32(xxx, yyy)
+    ans = torch.from_numpy(z).npu()
+
+    blocks = list(x.size())
+    strides = list(x.stride())
+    while len(blocks) < 5:
+        blocks.append(1)
+        strides.append(1)
+
+    grid = (1,)
+    triton_umulhi_4d_5d[grid](output, xx, yy, *blocks, *blocks, *strides)
+
+    test_common.validate_cmp(dtype, ans, output)
\ No newline at end of file
diff --git a/third_party/ascend/examples/generalization_cases/test_where.py b/third_party/ascend/examples/generalization_cases/test_where.py
new file mode 100644
index 000000000..80b7f124e
--- /dev/null
+++ b/third_party/ascend/examples/generalization_cases/test_where.py
@@ -0,0 +1,117 @@
+import logging
+
+import triton
+import triton.language as tl
+import torch
+import pytest
+import test_common
+from test_common import TestUtils
+import math
+
+
+def torch_pointwise(x0, x1):
+    res = torch.where(x0 < x1, x0, 1)
+    return res
+
+@triton.jit
+def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr,
+            XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr,
+            XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr):
+    xoffs = tl.program_id(0) * XB
+    yoffs = tl.program_id(1) * YB
+    zoffs = tl.program_id(2) * ZB
+
+    xidx = tl.arange(0, XB) + xoffs
+    yidx = tl.arange(0, YB) + yoffs
+    zidx = tl.arange(0, ZB) + zoffs
+
+    idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :]
+
+    X = tl.load(x_ptr + idx)
+    Y = tl.load(y_ptr + idx)
+
+    tmp2 = X < Y
+    ret = tl.where(tmp2, X, 1)
+
+    tl.store(output_ptr + idx, ret)
+
+
+@pytest.mark.parametrize('shape', TestUtils.full_shape)
+@pytest.mark.parametrize('dtype',
+                         ['bool', 'float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64'])
+def test_case2(dtype, shape):
+    # generate test data
+    x = test_common.generate_tensor(shape, 
dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + ans = torch_pointwise(x, y) + output = torch.zeros_like(ans) + + if len(shape) == 1: + fn_npu_[1, 1, shape[0]](output, x, y, z, 1, 1, 1, 1, 1, shape[0]) + elif len(shape) == 2: + if shape[0] > shape[1]: + fn_npu_[1, shape[0], 1](output, x, y, z, 1, 1, shape[1], 1, shape[0], shape[1]) + else: + fn_npu_[1, 1, shape[1]](output, x, y, z, 1, shape[0], 1, 1, shape[0], shape[1]) + elif len(shape) == 3: + if max(shape[0], shape[1], shape[2]) == shape[0]: + fn_npu_[shape[0], 1, 1](output, x, y, z, 1, shape[1], shape[2], shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, shape[2]](output, x, y, z, shape[0], shape[1], 1, shape[0], shape[1], shape[2]) + else: + fn_npu_[1, 1, 1](output, x, y, z, 1, 1, 1, 1, 1, 1) + + test_common.validate_cmp(dtype, ans, output) + + +@triton.jit +def fn_npu_multi_d(output_ptr, x_ptr, y_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIMS: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if DIMS > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if DIMS > 2: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if DIMS > 3: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if DIMS > 4: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + X = tl.load(x_ptr + offsets) + Y = tl.load(y_ptr + offsets) + + tmp2 = X < Y + ret = tl.where(tmp2, X, 1) + + tl.store(output_ptr + offsets, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (4, 2, 8, 4), + (2, 4, 2, 8, 1), + + (4, 3, 8, 1), + (3, 4, 2, 8, 4), +]) +@pytest.mark.parametrize('dtype', + ['float32', 'float16', 'bfloat16', 'int8', 'int16', 'int32', 'int64']) +def test_case_4d_5d(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + + ans = torch_pointwise(x, y) + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + grid = (1, ) + fn_npu_multi_d[grid](output, x, y, *triton_shape, len(shape)) + + test_common.validate_cmp(dtype, ans, output) diff --git a/third_party/ascend/examples/generalization_cases/test_xor.py b/third_party/ascend/examples/generalization_cases/test_xor.py new file mode 100644 index 000000000..df7acee9f --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_xor.py @@ -0,0 +1,161 @@ +import logging + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common +from test_common import TestUtils +import math + + +def torch_xor(x0, x1): + return x0 ^ x1 + + +@triton.jit +def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, + XNUMEL: tl.constexpr, YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xoffs = tl.program_id(0) * XB + yoffs = tl.program_id(1) * YB + zoffs = tl.program_id(2) * ZB + + xidx = tl.arange(0, XB) + xoffs + yidx = tl.arange(0, YB) + yoffs + zidx = tl.arange(0, ZB) + zoffs + + idx = xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + Y = tl.load(y_ptr + idx) + + ret = X ^ Y + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def triton_xor_4d_5d( + output_ptr, x_ptr, y_ptr, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, + BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, + SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, + STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + x_val = tl.load(x_ptr + offsets, masks) + y_val = tl.load(y_ptr + offsets, masks) + ret = x_val ^ y_val + tl.store(output_ptr + offsets, ret, mask=masks) + + +@pytest.mark.parametrize('shape', TestUtils.full_shape) +@pytest.mark.parametrize('dtype', + ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_case2(dtype, shape): + # 生成数据 + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + z = test_common.generate_tensor(shape, dtype).npu() + new_shape = shape + + output = torch.randint(1, new_shape, dtype=eval('torch.' 
+ dtype)).npu() + output1 = output + logging.debug(f"output.dtype={output.dtype}") + + ans = torch_xor(x, y) + + if len(shape) == 1: + XB = 1 + xnumel = 1 + YB = 1 + ynumel = 1 + ZB = shape[0] + znumel = shape[0] + elif len(shape) == 2: + XB = 1 + xnumel = 1 + YB = shape[0] + ynumel = shape[0] + ZB = shape[1] + znumel = shape[1] + else: + XB = shape[0] + xnumel = shape[0] + YB = shape[1] + ynumel = shape[1] + ZB = shape[2] + znumel = shape[2] + + grid = (1, 1, 1) + if x.numel() * x.element_size() >= 8192: + grid = (1, 1, ZB) + ZB = 1 + + fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel) + + test_common.validate_cmp(dtype, ans, output) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape4d + TestUtils.test_shape5d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_xor_4d_5d(shape, dtype): + logging.log(logging.DEBUG, f"shape = {shape}") + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + + logging.log(logging.DEBUG, f"output.dtype={output.dtype}") + + ans = x ^ y + + blocks = list(x.size()) + strides = list(x.stride()) + while len(blocks) < 5: + blocks.append(1) + strides.append(1) + + grid = (1,) + triton_xor_4d_5d[grid](output, x, y, *blocks, *blocks, *strides) + + test_common.validate_cmp(dtype, ans, output) + +invalid_types = [ + 'float16', + 'float32', + 'bfloat16', +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "unexpected type") +def test_invalid_types(sigtype): + N = 32 + x = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + z = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + output = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + fn_npu_[1, 1, 1](output, x, y, z, 32, 1, 1, 32, 1, 1) + diff --git a/third_party/ascend/examples/generalization_cases/test_xorsum.py b/third_party/ascend/examples/generalization_cases/test_xorsum.py new file mode 100644 index 000000000..002277d64 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_xorsum.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
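+# Generalization tests for tl.xor_sum (XOR reduction) on 1D-5D tensors. The reference
+# torch_xorsum below folds `^` along the chosen dim (or over the flattened tensor when
+# dim is None), and each Triton kernel result is checked against it with validate_cmp.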
+import math +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common +import functools +from test_common import TestUtils, check_ub_mem_overflow, get_dtype_size + + +# <<<<<<< test_xorsum_1d +def torch_xorsum(tensor, dim=None, keepdim=False): + if dim is None: + result = tensor.flatten()[0] + for x in tensor.flatten()[1:]: + result = result ^ x + return result + else: + assert dim < tensor.dim(), f"Invalid dim {dim} for tensor shape {tensor.shape}" + result = tensor.select(dim, 0) + for i in range(1, tensor.size(dim)): + result = result ^ tensor.select(dim, i) + if keepdim: + result = result.unsqueeze(dim) + return result + + +@triton.jit +def triton_xorsum_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) + tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + xoffset, None) + tmp4 = tl.xor_sum(tmp0, 0) + tl.store(out_ptr1, tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape1d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +def test_xorsum_1d(dtype, shape): + if check_ub_mem_overflow(dtype, shape): + return + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty(1, dtype=eval("torch." + dtype)).npu() + numel = shape[0] + triton_xorsum_1d[1, 1, 1](x0, triton_res, numel, numel) + torch_res = torch_xorsum(x0, dim=0, keepdim=True) + test_common.validate_cmp(dtype, triton_res, torch_res) + + +# >>>>>>> test_xorsum_1d + +# <<<<<<< test_xorsum_2d +@triton.jit +def triton_xorsum_2d(in_ptr0, out_ptr0, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, + NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0, MNUMEL) + nblk_idx = tl.arange(0, NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:, None]) & (nmask[None, :]) + idx = mblk_idx[:, None] * N + nblk_idx[None, :] + x = tl.load(in_ptr0 + idx, mask=mask, other=-float('inf')) + tmp4 = tl.xor_sum(x, dim) + if dim == 0: + tl.store(out_ptr0 + tl.arange(0, N), tmp4, None) + else: + tl.store(out_ptr0 + tl.arange(0, M), tmp4, None) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape2d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize('dim', [0, 1]) +def test_xorsum_2d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype in ['int8', 'int16', 'int32', 'int64']: + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + elif dtype in ['bool']: + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 5): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + shapex, shapey = shape + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[1 - dim], ], dtype=eval("torch." 
+ dtype)).npu() + triton_xorsum_2d[1, 1, 1](x0, triton_res, dim, shapex, shapey, shapex, shapey) + torch_res = torch_xorsum(x0, dim=dim, keepdim=False) + test_common.validate_cmp(dtype, triton_res, torch_res) + + +# >>>>>>> test_xorsum_2d + +# <<<<<<< test_xorsum_3d +def torch_xorsum_3d(x0, no_reduce_dim): + inp = x0 if x0.device == "cpu" else x0.cpu() + if no_reduce_dim == 0: + return torch_xorsum(torch_xorsum(inp, 1), 1).npu() + elif no_reduce_dim == 1: + return torch_xorsum(torch_xorsum(inp, 0), 1).npu() + elif no_reduce_dim == 2: + return torch_xorsum(torch_xorsum(inp, 0), 0).npu() + else: + assert False, f"no reduce dim not right, no_reduce_dim = {no_reduce_dim}" + + +@triton.jit +def triton_xorsum_3d_0_1(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + x = tl.load(in_ptr + idx) + tmp = tl.xor_sum(x, 0) + ret = tl.xor_sum(tmp, 0) + oidx = zidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def triton_xorsum_3d_0_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + x = tl.load(in_ptr + idx) + tmp = tl.xor_sum(x, 0) + ret = tl.xor_sum(tmp, 1) + oidx = yidx + tl.store(out_ptr + oidx, ret) + + +@triton.jit +def triton_xorsum_3d_1_2(in_ptr, out_ptr, + xnumel: tl.constexpr, ynumel: tl.constexpr, znumel: tl.constexpr, + XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = xidx[:, None, None] * ynumel * znumel + yidx[None, :, None] * znumel + zidx[None, None, :] + x = tl.load(in_ptr + idx) + tmp = tl.xor_sum(x, 1) + ret = tl.xor_sum(tmp, 1) + oidx = xidx + tl.store(out_ptr + oidx, ret) + + +def triton_xorsum_3d(in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB, no_reduce_dim): + if no_reduce_dim == 0: + triton_xorsum_3d_1_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 1: + triton_xorsum_3d_0_2[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + elif no_reduce_dim == 2: + triton_xorsum_3d_0_1[1, 1, 1](in_ptr, out_ptr, xnumel, ynumel, znumel, XB, YB, ZB) + + +@pytest.mark.parametrize('shape', TestUtils.test_shape3d) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize('no_reduce_dim', [0, 1, 2]) +def test_xorsum_3d(dtype, shape, no_reduce_dim): + x0 = test_common.generate_tensor(shape, dtype).npu() + triton_res = torch.empty([shape[no_reduce_dim], ], dtype=eval("torch." 
+ dtype)).npu() + triton_xorsum_3d(x0, triton_res, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2], no_reduce_dim) + torch_res = torch_xorsum_3d(x0, no_reduce_dim) + test_common.validate_cmp(dtype, triton_res, torch_res) + + +# >>>>>>> test_xorsum_3d + + +# <<<<<<< test_xorsum_4d +@triton.jit +def triton_xorsum_multi_d(in_ptr, out_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, NB: tl.constexpr, DIMS: tl.constexpr, DIM: tl.constexpr, REDUCE_NUMEL: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if DIMS > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if DIMS > 2: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if DIMS > 3: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if DIMS > 4: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + x = tl.load(in_ptr + offsets) + + if DIM is not None: + ret = tl.reshape(tl.xor_sum(x, DIM), REDUCE_NUMEL) + o_offsets = tl.arange(0, REDUCE_NUMEL) + tl.store(out_ptr + o_offsets, ret) + else: + ret = tl.xor_sum(x, DIM) + tl.store(out_ptr, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (4, 2, 8, 4), + (4, 3, 8, 1), +]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize('dim', [0, 1, 2, 3]) +def test_xorsum_4d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype in ['int8', 'int16', 'int32', 'int64']: + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + print(f"dtype:{dtype} shape:{shape} mem overflow") + return + + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_xorsum(x0, dim=dim, keepdim=False) + triton_res = torch.empty_like(torch_res, dtype=eval("torch." + dtype)).npu() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None + grid = (1,) + triton_xorsum_multi_d[grid](x0, triton_res, *triton_shape, len(shape), dim, reduce_numel) + test_common.validate_cmp(dtype, triton_res, torch_res) + + +# >>>>>>> test_xorsum_4d + + +# <<<<<<< test_xorsum_5d +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('shape', [ + (2, 4, 2, 8, 4), + (3, 4, 2, 8, 1), +]) +@pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64', 'bool']) +@pytest.mark.parametrize('dim', [0, 1, 2, 3, 4]) +def test_xorsum_5d(dtype, shape, dim): + dtype_size = get_dtype_size(dtype) + if dtype in ['int8', 'int16', 'int32', 'int64']: + if dtype_size * math.prod(shape) >= (TestUtils.ub_size / 3): + print(f"dtype:{dtype} shape:{shape} mem overflow") + return + + x0 = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch_xorsum(x0, dim=dim, keepdim=False) + triton_res = torch.empty_like(torch_res, dtype=eval("torch." 
+ dtype)).npu() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + reduce_numel = math.prod(triton_shape) // triton_shape[dim] if dim is not None else None + grid = (1,) + triton_xorsum_multi_d[grid](x0, triton_res, *triton_shape, len(shape), dim, reduce_numel) + test_common.validate_cmp(dtype, triton_res, torch_res) + + +# >>>>>>> test_xorsum_5d + + +if __name__ == "__main__": + test_xorsum_3d('int8', (3, 3, 3), 0) diff --git a/third_party/ascend/examples/generalization_cases/test_zeros_op.py b/third_party/ascend/examples/generalization_cases/test_zeros_op.py new file mode 100644 index 000000000..18ea0b375 --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_zeros_op.py @@ -0,0 +1,533 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + +import math +import pytest +import random +import torch +import torch_npu +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, check_ub_mem_overflow + + +@triton.jit +def fn_npu_int8_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int8) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int16) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int32) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int64_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int64) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_fp16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: 
tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.float16) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_fp32_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.float32) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_bf16_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.bfloat16) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_bool_3d(output_ptr, X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, XNUMEL: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + xidx = tl.arange(0, XNUMEL) + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Xmask = xidx < X + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Xmask[:, None, None]) & (Ymask[None, :, None]) & (Zmask[None, None, :]) + ret = tl.zeros((XNUMEL, YNUMEL, ZNUMEL), dtype=tl.int1) + oidx = xidx[:, None, None] * Y * Z + yidx[None, :, None] * Z + zidx[None, None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int8_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int8) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int16) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int32) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int64_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + 
YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int64) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_fp16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.float16) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_fp32_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.float32) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_bf16_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.bfloat16) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_bool_2d(output_ptr, Y: tl.constexpr, Z: tl.constexpr, + YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr): + yidx = tl.arange(0, YNUMEL) + zidx = tl.arange(0, ZNUMEL) + Ymask = yidx < Y + Zmask = zidx < Z + mask = (Ymask[:, None]) & (Zmask[None, :]) + ret = tl.zeros((YNUMEL, ZNUMEL), dtype=tl.int1) + oidx = yidx[:, None] * Z + zidx[None, :] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int8_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.int8) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.int16) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int32_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.int32) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int64_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.int64) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_fp16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.float16) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_fp32_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.float32) + oidx = zidx[:] + 
tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_bf16_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.bfloat16) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_bool_1d(output_ptr, Z: tl.constexpr, ZNUMEL: tl.constexpr): + zidx = tl.arange(0, ZNUMEL) + Zmask = zidx < Z + mask = (Zmask[:]) + ret = tl.zeros((ZNUMEL,), dtype=tl.int1) + oidx = zidx[:] + tl.store(output_ptr + oidx, ret, mask=mask) + + +@triton.jit +def fn_npu_int8_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.int8) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_int16_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.int16) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_int32_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.int32) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_int64_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.int64) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_fp16_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.float16) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_fp32_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.float32) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_bf16_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.bfloat16) + tl.store(output_ptr, zero) + + +@triton.jit +def fn_npu_bool_0d(output_ptr, N: tl.constexpr): + zero = tl.zeros((), dtype=tl.int1) + tl.store(output_ptr, zero) + + +test_dtype = ['int8', 'int16', 'int32', 'int64', 'float16', 'float32', 'bfloat16', 'bool'] +test_shape0d = [()] +test_shape1d = TestUtils.test_shape1d +test_shape2d = TestUtils.test_shape2d +test_shape3d = TestUtils.test_shape3d + +# 定义 dtype 到 (test_func, test_sigtype) 的映射 +dtype_mapping3d = { + 'int8': (fn_npu_int8_3d, torch.int8), + 'int16': (fn_npu_int16_3d, torch.int16), + 'int32': (fn_npu_int32_3d, torch.int32), + 'int64': (fn_npu_int64_3d, torch.int64), + 'float16': (fn_npu_fp16_3d, torch.float16), + 'float32': (fn_npu_fp32_3d, torch.float32), + 'bfloat16': (fn_npu_bf16_3d, torch.bfloat16), + 'bool': (fn_npu_bool_3d, torch.bool), +} +dtype_mapping2d = { + 'int8': (fn_npu_int8_2d, torch.int8), + 'int16': (fn_npu_int16_2d, torch.int16), + 'int32': (fn_npu_int32_2d, torch.int32), + 'int64': (fn_npu_int64_2d, torch.int64), + 'float16': (fn_npu_fp16_2d, torch.float16), + 'float32': (fn_npu_fp32_2d, torch.float32), + 'bfloat16': (fn_npu_bf16_2d, torch.bfloat16), + 'bool': (fn_npu_bool_2d, torch.bool), +} +dtype_mapping1d = { + 'int8': (fn_npu_int8_1d, torch.int8), + 'int16': (fn_npu_int16_1d, torch.int16), + 'int32': (fn_npu_int32_1d, torch.int32), + 'int64': (fn_npu_int64_1d, torch.int64), + 'float16': (fn_npu_fp16_1d, torch.float16), + 'float32': (fn_npu_fp32_1d, torch.float32), + 'bfloat16': (fn_npu_bf16_1d, torch.bfloat16), + 'bool': (fn_npu_bool_1d, torch.bool), +} +dtype_mapping0d = { + 'int8': (fn_npu_int8_0d, torch.int8), + 'int16': (fn_npu_int16_0d, torch.int16), + 'int32': (fn_npu_int32_0d, torch.int32), + 'int64': (fn_npu_int64_0d, torch.int64), + 'float16': (fn_npu_fp16_0d, torch.float16), + 'float32': (fn_npu_fp32_0d, torch.float32), + 'bfloat16': (fn_npu_bf16_0d, torch.bfloat16), + 'bool': (fn_npu_bool_0d, torch.bool), +} + +# 生成测试用例 +testlist = [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape0d + for func, 
dtype in [dtype_mapping0d[sigtype]] # 直接解包映射结果 +] + +testlist += [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape1d + for func, dtype in [dtype_mapping1d[sigtype]] # 直接解包映射结果 +] + +testlist += [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape2d + for func, dtype in [dtype_mapping2d[sigtype]] # 直接解包映射结果 +] + +testlist += [ + (func, sigtype, dtype, shape) + for sigtype in test_dtype + for shape in test_shape3d + for func, dtype in [dtype_mapping3d[sigtype]] # 直接解包映射结果 +] + + +@pytest.mark.parametrize('testfunc, sigtype, dtype, shape', testlist) +def test_npu(testfunc, sigtype, dtype, shape): + if check_ub_mem_overflow(sigtype, shape): + pytest.skip(f"dtype:{sigtype} shape:{shape} mem overflow") + x = 0 + output = 0 + if len(shape) == 3: + x = torch.full((shape[0], shape[1], shape[2]), 0, dtype=dtype).npu() + output = torch.randint(1, (shape[0], shape[1], shape[2]), dtype=dtype).npu() + testfunc[(1, 1, 1)](output, shape[0], shape[1], shape[2], shape[0], shape[1], shape[2]) + if len(shape) == 2: + x = torch.full((shape[0], shape[1]), 0, dtype=dtype).npu() + output = torch.randint(1, (shape[0], shape[1]), dtype=dtype).npu() + shape0 = shape[0] + shape1 = shape[1] + if x.numel() * x.element_size() >= 8192: + grid = (shape0, 1, 1) + shape0 = 1 + else: + grid = (1, 1, 1) + testfunc[grid](output, shape0, shape1, shape0, shape1) + if len(shape) == 1: + x = torch.full((shape[0],), 0, dtype=dtype).npu() + output = torch.randint(1, (shape[0],), dtype=dtype).npu() + testfunc[1, 1, 1](output, shape[0], shape[0]) + if len(shape) == 0: + output = torch.randint(1, size=shape, dtype=dtype).npu() + x = torch.zeros_like(output) + testfunc[(1,)](output_ptr=output, N=1) + test_common.validate_cmp(sigtype, output, x) + + +@triton.jit +def fn_npu_multi_d(output_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, + NB: tl.constexpr): + dtype = output_ptr.type.element_ty + + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + if (YB * ZB * MB * NB) == 1: + ret = tl.zeros((XB,), dtype=dtype) + elif (ZB * MB * NB) == 1: + ret = tl.zeros((XB, YB), dtype=dtype) + elif (MB * NB) == 1: + ret = tl.zeros((XB, YB, ZB), dtype=dtype) + elif NB == 1: + ret = tl.zeros((XB, YB, ZB, MB), dtype=dtype) + else: + ret = tl.zeros((XB, YB, ZB, MB, NB), dtype=dtype) + + tl.store(output_ptr + offsets, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('param_list', + [ + ('float32', (4, 2, 16, 16)), + ('float32', (2, 4, 2, 16, 16)), + + ('float32', (4, 2, 16, 16)), + ('float32', (2, 4, 2, 16, 16)), + + ('float32', (4, 2, 16, 16)), + ('float32', (2, 4, 2, 16, 16)), + ] + ) +def test_case_4d_5d(param_list): + dtype, shape = param_list + if check_ub_mem_overflow(dtype, shape): + pytest.skip(f"dtype:{dtype} shape:{shape} mem overflow") + y_ref = torch.full(shape, 0, dtype=eval('torch.' + dtype)).npu() + print(f"y_ref = {torch.flatten(y_ref)[0:4]}") + + y_cal = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + fn_npu_multi_d[(1,)](y_cal, *triton_shape) + print(f"y_cal = {torch.flatten(y_cal)[0:4]}") + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_zeroslike.py b/third_party/ascend/examples/generalization_cases/test_zeroslike.py new file mode 100644 index 000000000..92827e5de --- /dev/null +++ b/third_party/ascend/examples/generalization_cases/test_zeroslike.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import logging +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +import test_common +from test_common import TestUtils, check_ub_mem_overflow + + +@triton.jit +def fn_npu_0d(output_ptr, x_ptr, YB: tl.constexpr): + yidx = tl.arange(0, YB) + + idx = yidx + + X = tl.load(x_ptr + idx) + + ret = tl.zeros_like(X) + + oidx = yidx + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_1d(output_ptr, x_ptr, YB: tl.constexpr): + yidx = tl.arange(0, YB) + + idx = yidx + + X = tl.load(x_ptr + idx) + + ret = tl.zeros_like(X) + + oidx = yidx + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr): + pid = tl.program_id(0) + yidx = tl.arange(0, YB)[:, None] + pid * YB + zidx = tl.arange(0, ZB)[None, :] + + idx = yidx * ZB + zidx + + X = tl.load(x_ptr + idx) + + ret = tl.zeros_like(X) + + oidx = yidx * ZB + zidx + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB)[:, None, None] * ZB * KB + zidx = tl.arange(0, ZB)[None, :, None] * KB + kidx = tl.arange(0, KB)[None, None, :] + + idx = yidx + zidx + kidx + + X = tl.load(x_ptr + idx) + + ret = tl.zeros_like(X) + + oidx = yidx + zidx + kidx + + tl.store(output_ptr + oidx, ret) + + +test_shape0d = [()] +testlist = test_shape0d + TestUtils.test_shape1_2_3d + + +@pytest.mark.parametrize('shape', testlist) +@pytest.mark.parametrize('dtype', TestUtils.dtype_list) +def test_npu(shape, dtype): + logging.debug(f'dtype:{dtype} shape:{shape}') + if check_ub_mem_overflow(dtype, shape): + return + x = torch.full(shape, 0, dtype=eval('torch.' + dtype)).npu() + triton_res = torch.empty(shape, dtype=eval('torch.' 
+ dtype)).npu() + torch_res = x + + if len(shape) == 0: + fn_npu_0d[1, 1, 1](triton_res, x, 1) + elif len(shape) == 1: + fn_npu_1d[1, 1, 1](triton_res, x, shape[0]) + elif len(shape) == 2: + fn_npu_2d[shape[0], 1, 1](triton_res, x, 1, shape[1]) + elif len(shape) == 3: + fn_npu_3d[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + + test_common.validate_cmp(dtype, triton_res, torch_res) + + +@triton.jit +def fn_npu_multi_d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, MB: tl.constexpr, + NB: tl.constexpr): + offsets = tl.arange(0, XB) * (YB * ZB * MB * NB) + if (YB * ZB * MB * NB) > 1: + offsets = offsets[:, None] + tl.arange(0, YB)[None, :] * (ZB * MB * NB) + if (ZB * MB * NB) > 1: + offsets = offsets[:, :, None] + tl.arange(0, ZB)[None, None, :] * (MB * NB) + if (MB * NB) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, MB)[None, None, None, :] * NB + if NB > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, NB)[None, None, None, None, :] + + X = tl.load(x_ptr + offsets) + ret = tl.zeros_like(X) + + tl.store(output_ptr + offsets, ret) + + +@pytest.mark.shape_4d_5d +@pytest.mark.parametrize('param_list', + [ + ('float32', (4, 2, 16, 16)), + ('float32', (2, 4, 2, 16, 16)), + + ('float32', (4, 2, 16, 16)), + ('float32', (2, 4, 2, 16, 16)), + + ('float32', (4, 2, 16, 16)), + ('float32', (2, 4, 2, 16, 16)), + ] + ) +def test_case_4d_5d(param_list): + dtype, shape = param_list + if check_ub_mem_overflow(dtype, shape): + return + x0 = test_common.generate_tensor(shape, dtype) + y_ref = torch.zeros_like(x0, dtype=eval('torch.' + dtype)).npu() + print(f"y_ref = {torch.flatten(y_ref)[0:4]}") + y_cal = torch.ones(shape, dtype=eval('torch.' + dtype)).npu() + + triton_shape = [*shape] + while len(triton_shape) < 5: + triton_shape.append(1) + fn_npu_multi_d[(1,)](y_cal, x0, *triton_shape) + print(f"y_cal = {torch.flatten(y_cal)[0:4]}") + test_common.validate_cmp(dtype, y_cal, y_ref) + + +if __name__ == "__main__": + for dtype in TestUtils.dtype_list: + for shape in [(37,), (37, 3), (1, 22, 39)]: + test_npu(shape, dtype) diff --git a/third_party/ascend/examples/inductor_cases/run_inductor_test.sh b/third_party/ascend/examples/inductor_cases/run_inductor_test.sh new file mode 100644 index 000000000..4a0977f15 --- /dev/null +++ b/third_party/ascend/examples/inductor_cases/run_inductor_test.sh @@ -0,0 +1,151 @@ +inductor_skip_list=( + "test_check_accuracy.py" + "test_debug_msg.py" + "test_embedding.py" + "test_force_fallback.py" + "test_foreach_add.py" + "test_geometric.py" + "test_lazy_register.py" +) + +TEST_inductor="${WORKSPACE}/ascend/examples/inductor_cases" +# 定义统计文件路径 +SUMMARY_FILE="${WORKSPACE}/ascend/examples/summary.txt" + +# install daily torch_npu +current_date=$(date +%Y%m%d) +wget https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.6.0/${current_date}.3/pytorch_v2.6.0_py311.tar.gz +tar -zxvf pytorch_v2.6.0_py311.tar.gz +pip install *.dev${current_date}-cp311-cp311-manylinux_2_28_aarch64.whl + +# remove inductor and triton cache +if [ -d /tmp/torchinductor_* ];then + rm -rf /tmp/torchinductor_* +fi + +if [ -d ~/.triton/dump ];then + rm -rf ~/.triton/dump +fi + +if [ -d ~/.triton/cache ];then + rm -rf ~/.triton/cache +fi + +cd ${TEST_inductor} +git init +git remote add origin http://gitee.com/ascend/pytorch.git +git config core.sparsecheckout true +echo "test/_inductor" >> .git/info/sparse-checkout +git pull origin v2.6.0:master +TEST_inductor_cases_path="${TEST_inductor}/test/_inductor" +cd ${TEST_inductor_cases_path} +export 
PYTHONPATH="${PYTHONPATH}:${TEST_inductor_cases_path}" + +# 记录跳过的测试用例 +echo -e "\n======= Inductor 测试跳过的用例 =======" >> $SUMMARY_FILE +for skip_case in ${inductor_skip_list[@]}; +do + if [ -e "${TEST_inductor_cases_path}/${skip_case}" ];then + echo "跳过测试用例: ${skip_case}" | tee -a $SUMMARY_FILE + mv ${skip_case} "${skip_case}_skip" + fi +done + +# 创建临时日志目录 +LOG_DIR=$(mktemp -d) +INDUCTOR_CASE_LOG_FILE="$LOG_DIR/test_inductor_case_$(date +%Y%m%d).log" + +# 记录测试开始时间 +start_time=$(date +"%Y-%m-%d %H:%M:%S") + +# 执行测试并生成JUnit报告 +pytest -n 16 --dist=loadfile . \ + --junitxml="$LOG_DIR/results.xml" \ + 2>&1 | tee "$INDUCTOR_CASE_LOG_FILE" + +# 解析统计信息 +# 使用Python解析JUnit XML报告 +python3 -c " +import xml.etree.ElementTree as ET +import os + +xml_file = '$LOG_DIR/results.xml' +if not os.path.exists(xml_file): + print('JUnitXML report not found:', xml_file) + exit(1) + +tree = ET.parse(xml_file) +root = tree.getroot() + +total_tests = 0 +passed_tests = 0 +failed_tests = 0 +skipped_tests = 0 +error_tests = 0 + +# 遍历所有testsuite +for testsuite in root.findall('testsuite'): + total_tests += int(testsuite.get('tests', 0)) + skipped_tests += int(testsuite.get('skipped', 0)) + error_tests += int(testsuite.get('errors', 0)) + failed_tests += int(testsuite.get('failures', 0)) + +# 计算通过用例数 +passed_tests = total_tests - skipped_tests - error_tests - failed_tests + +# 输出统计信息 +print(f'total_tests={total_tests}') +print(f'passed_tests={passed_tests}') +print(f'failed_tests={failed_tests}') +print(f'skipped_tests={skipped_tests}') +print(f'error_tests={error_tests}') +" > $LOG_DIR/stats.tmp + +# 加载统计结果 +source $LOG_DIR/stats.tmp +rm $LOG_DIR/stats.tmp + +# 计算测试持续时间 +end_time=$(date +"%Y-%m-%d %H:%M:%S") +duration=$(( $(date -d "$end_time" +%s) - $(date -d "$start_time" +%s) )) +duration_str=$(printf "%02dh %02dm %02ds" $((duration/3600)) $(((duration%3600)/60)) $((duration%60))) + +# 计算通过率 +if [ $total_tests -gt 0 ]; then + pass_rate=$(( 100 * passed_tests / total_tests )) +else + pass_rate=0 +fi + +# 生成统计信息摘要 +stats_summary=" +inductor 测试用例结果摘要: +------------------------ +开始时间: $start_time +结束时间: $end_time +总耗时: $duration_str +------------------------ +总用例数: $total_tests +成功用例: $passed_tests +失败用例: $failed_tests +跳过用例: $skipped_tests +错误用例: $error_tests +------------------------ +通过率: ${pass_rate}% (成功/总数) +并行度: 16个进程 +------------------------ +" + +# 输出统计信息到控制台 +echo "$stats_summary" + +# 追加统计信息到summary.txt +echo "$stats_summary" >> $SUMMARY_FILE + +# 保存原始日志文件 +cp "$INDUCTOR_CASE_LOG_FILE" "/home/daily_log/" + +# 清理临时文件 +rm -rf "$LOG_DIR" + +echo "测试统计信息已追加到: $SUMMARY_FILE" \ No newline at end of file diff --git a/third_party/ascend/examples/model_cases/deberta.py b/third_party/ascend/examples/model_cases/deberta.py new file mode 100644 index 000000000..9e6d4cdfd --- /dev/null +++ b/third_party/ascend/examples/model_cases/deberta.py @@ -0,0 +1,47 @@ +import logging +import os + +import torch +import torch_npu +import torch_npu._inductor + +from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification + +os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + +logging.basicConfig(level=logging.DEBUG) + +torch.npu.config.allow_internal_format = False +torch.manual_seed(0) +torch.npu.manual_seed(0) +tokenizer = AutoTokenizer.from_pretrained("./microsoft/deberta-v3-large") + +sample_texts = ["This is a positive example.", "This might be negative."] * 128 + +model_ = AutoModelForTokenClassification.from_pretrained("./microsoft/deberta-v3-large", device_map="npu:0") +model_.eval() 
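+# Tokenize the sample batch, run the classifier once in eager mode and once under
+# torch.compile, then require the two logits tensors to match within 1e-4.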
+ +inputs = tokenizer( + sample_texts, + max_length=512, + padding="longest", + truncation=True, + return_tensors="pt", + add_special_tokens=True +).to("npu:0") + + +def model(**model_inputs): + with torch.no_grad(): + return model_(**model_inputs).logits + +y = model(**inputs) +logging.info("result eager: " + str(torch.flatten(y)[:100])) + +model_compiled = torch.compile(model_) + +z = model_compiled(**inputs) +logging.info("result compiled: " + str(torch.flatten(z)[:100])) + +torch.testing.assert_close(y, z, atol=1e-4, rtol=1e-4) +logging.info("deberta accuracy check pass!") \ No newline at end of file diff --git a/third_party/ascend/examples/model_cases/llama.py b/third_party/ascend/examples/model_cases/llama.py new file mode 100644 index 000000000..44e208777 --- /dev/null +++ b/third_party/ascend/examples/model_cases/llama.py @@ -0,0 +1,37 @@ +import logging +import os + +import torch +import torch_npu +import torch_npu._inductor + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + +os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + +logging.basicConfig(level=logging.DEBUG) + +torch.npu.config.allow_internal_format = False +torch.manual_seed(0) +torch.npu.manual_seed(0) +tokenizer = AutoTokenizer.from_pretrained("./Meta-Llama-3-8B") + +inputs = tokenizer("Hello, how to make China great again?", return_tensors="pt").to("npu:0") +model_ = AutoModelForCausalLM.from_pretrained("./Meta-Llama-3-8B", device_map="npu:0", _attn_implementation="eager") +model_.eval() + + +def model(**model_inputs): + with torch.no_grad(): + return model_(**model_inputs).logits + +y = model(**inputs) +logging.info("result eager: " + str(torch.flatten(y)[:100])) + +model_compiled = torch.compile(model_) + +z = model_compiled(**inputs) +logging.info("result compiled: " + str(torch.flatten(z)[:100])) + +torch.testing.assert_close(y, z, atol=1e-4, rtol=1e-4) +logging.info("llama accuracy check pass!") \ No newline at end of file diff --git a/third_party/ascend/examples/model_cases/qwen.py b/third_party/ascend/examples/model_cases/qwen.py new file mode 100644 index 000000000..651072791 --- /dev/null +++ b/third_party/ascend/examples/model_cases/qwen.py @@ -0,0 +1,37 @@ +import logging +import os + +import torch +import torch_npu +import torch_npu._inductor + +from transformers import AutoTokenizer, AutoModelForCausalLM + +os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + +logging.basicConfig(level=logging.DEBUG) + +torch.npu.config.allow_internal_format = False +torch.manual_seed(0) +torch.npu.manual_seed(0) +tokenizer = AutoTokenizer.from_pretrained("./Qwen2.5-0.5B-Instruct") + +inputs = tokenizer("Hello, how to make China great again?", return_tensors="pt").to("npu:0") +model_ = AutoModelForCausalLM.from_pretrained("./Qwen2.5-0.5B-Instruct", device_map="npu:0") +model_.eval() + + +def model(**model_inputs): + with torch.no_grad(): + return model_(**model_inputs).logits + +y = model(**inputs) +logging.info("result eager: " + str(torch.flatten(y)[:100])) + +model_compiled = torch.compile(model_) + +z = model_compiled(**inputs) +logging.info("result compiled: " + str(torch.flatten(z)[:100])) + +torch.testing.assert_close(y, z, atol=1e-4, rtol=1e-4) +logging.info("qwen accuracy check pass!") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d.py b/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d.py new file mode 100644 index 000000000..69abff3bc --- /dev/null +++ 
b/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import torch_npu + +#number of block per aiv core +NBLOCKS = 32 +#constants +TL_DTYPE_ATTN = tl.bfloat16 + +@triton.jit +def tl_fn_forward_update( output_ptr0, output_ptr1, output_ptr2, prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max, + cur_softmax_sum, S:tl.constexpr,B:tl.constexpr,H:tl.constexpr,N:tl.constexpr, E:tl.constexpr, + CUR_SM_N_STRIDE:tl.constexpr, PREV_SM_N_STRIDE:tl.constexpr, + S_NBLOCKS:tl.constexpr, S_SUB:tl.constexpr, BRC_SIZE:tl.constexpr): + + S_E_SUB_SIZE:tl.constexpr = S_SUB * E + D:tl.constexpr = H // N + S_BLOCK:tl.constexpr = (S + S_NBLOCKS - 1) // S_NBLOCKS + S_NSUB:tl.constexpr = (S_BLOCK + S_SUB - 1) // S_SUB + + D_SUB:tl.constexpr = E * BRC_SIZE + D_NLOOP:tl.constexpr = (D + D_SUB - 1) // D_SUB + + LOOP_COUNT : tl.constexpr = (S_BLOCK * B * N + S_SUB - 1) // S_SUB + block_idx = tl.program_id(0) + s_block_offset = block_idx % S_NBLOCKS * S_BLOCK + + for loop_index in range(LOOP_COUNT): + b = loop_index // (N * S_NSUB) + n = (loop_index // S_NSUB) % N + s_loop_offset = loop_index % S_NSUB * S_SUB + s = s_block_offset + s_loop_offset + tl.arange(0, S_SUB) #index on S axis + se_offset = E * (s_block_offset + s_loop_offset) + tl.arange(0, S_E_SUB_SIZE) + #assume no slice on N axis + offsets_prev = PREV_SM_N_STRIDE * (b * N + n) + se_offset + offsets_cur = CUR_SM_N_STRIDE * (b * N + n) + se_offset + mask0 = None + prev_softmax_max_local = tl.load(prev_softmax_max + offsets_prev, mask0) + cur_softmax_max_local = tl.load(cur_softmax_max + offsets_cur, mask0) + softmax_max = tl.maximum(prev_softmax_max_local, cur_softmax_max_local) + prev_scale = tl.exp(prev_softmax_max_local - softmax_max) + cur_scale = tl.exp(cur_softmax_max_local - softmax_max) + prev_softmax_sum_local = tl.load(prev_softmax_sum + offsets_prev, mask0) + cur_softmax_sum_local = tl.load(cur_softmax_sum + offsets_cur, mask0) + prev_softmax_sum_scaled = prev_softmax_sum_local * prev_scale + cur_softmax_sum_scaled = cur_softmax_sum_local * cur_scale + softmax_sum = prev_softmax_sum_scaled + cur_softmax_sum_scaled + prev_out_scale_local = prev_softmax_sum_scaled / softmax_sum + cur_out_scale_local = cur_softmax_sum_scaled / softmax_sum + prev_out_scale_out = prev_out_scale_local + cur_out_scale_out = cur_out_scale_local + mask1 = None + tl.store(output_ptr1 + offsets_cur, softmax_max, mask0) + tl.store(output_ptr2 + offsets_cur, softmax_sum, mask0) + prev_out_scale = prev_out_scale_local + prev_out_scale = prev_out_scale.reshape(S_SUB,1,E)#(s_sub,8)->(1,1,s_sub,1,8) + cur_out_scale = cur_out_scale_local + cur_out_scale = cur_out_scale.reshape(S_SUB,1,E) + for d_index in range(D_NLOOP): + # (s,b,h) -> (s,b,n*d) -> (s,b,n,d) + d = d_index * D_SUB + tl.arange(0, D_SUB) + offsets2 = s[:,None] * B * H + b * H + n * D + d[None,:] + mask2 = None + prev_attn_out_local = tl.load(prev_attn_out + offsets2, mask2) + cur_attn_out_local = tl.load(cur_attn_out + offsets2, mask2) + prev_attn_out_local = prev_attn_out_local.to(tl.float32) + cur_attn_out_local = cur_attn_out_local.to(tl.float32) + prev_attn_out_local = prev_attn_out_local.reshape(S_SUB,BRC_SIZE,E) + cur_attn_out_local = cur_attn_out_local.reshape(S_SUB,BRC_SIZE,E) + prev_out_scale_brc = prev_out_scale + prev_out_scale_brc = prev_out_scale_brc.broadcast_to(S_SUB,BRC_SIZE,E) + 
cur_out_scale_brc = cur_out_scale
+            cur_out_scale_brc = cur_out_scale_brc.broadcast_to(S_SUB, BRC_SIZE, E)
+            attn_out_local = prev_attn_out_local * prev_out_scale_brc + cur_attn_out_local * cur_out_scale_brc
+            attn_out_local = attn_out_local.reshape(S_SUB, D_SUB)
+            attn_out_local = attn_out_local.to(TL_DTYPE_ATTN)
+            tl.store(output_ptr0 + offsets2, attn_out_local, mask2)
+
+
+#target_name = torch.npu.get_device_name(torch.npu.current_device())
+guards = {"dummy": None}
+def forward_update_triton(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out, cur_softmax_max,
+                          cur_softmax_sum):
+
+    # size of sub_block in one block
+    S_SUB_SIZE = 64
+    BROADCAST_SIZE = 8
+
+    (S, B, H) = cur_attn_out.shape
+    N = prev_softmax_max.shape[1]
+    E = prev_softmax_max.shape[3]
+    D = H // N
+    D_SUB = E * BROADCAST_SIZE
+    # (b, n, s, 8)
+    PREV_SM_N_STRIDE = prev_softmax_max.stride()[1]
+    CUR_SM_N_STRIDE = cur_softmax_max.stride()[1]
+
+    GUARD = (S % NBLOCKS == 0 and H % N == 0 and (H // N) % (D_SUB) == 0)
+
+    if not GUARD:
+        print(f"parameter does not meet compiling GUARD, fallback to eager forward_update \
+              (S,H,N,D,D_SUB,NBLOCKS):{S},{H},{N},{D},{D_SUB},{NBLOCKS}")
+
+    org_dtype = cur_attn_out.dtype
+    device_id = cur_attn_out.device.index
+    device = "npu:" + str(device_id)
+
+    softmax_max = torch.empty_strided(cur_softmax_max.shape, cur_softmax_max.stride(),
+                                      dtype=cur_softmax_max.dtype, device=device)
+    softmax_sum = torch.empty_strided(cur_softmax_sum.shape, cur_softmax_sum.stride(),
+                                      dtype=cur_softmax_sum.dtype, device=device)
+    attn_out = torch.empty_strided(cur_attn_out.shape, cur_attn_out.stride(), dtype=cur_attn_out.dtype, device=device)
+
+    tl_fn_forward_update[NBLOCKS, 1, 1](attn_out, softmax_max, softmax_sum, prev_attn_out, prev_softmax_max,
+        prev_softmax_sum, cur_attn_out, cur_softmax_max, cur_softmax_sum, S=S, B=B, H=H, N=N, E=E,
+        CUR_SM_N_STRIDE=CUR_SM_N_STRIDE, PREV_SM_N_STRIDE=PREV_SM_N_STRIDE, S_NBLOCKS=NBLOCKS, S_SUB=S_SUB_SIZE,
+        BRC_SIZE=BROADCAST_SIZE, debug=True)
+
+    return attn_out, softmax_max, softmax_sum
+
+
+@triton.jit
+def tl_fn_backward_update(dq, dk, dv, cur_dq, cur_dk, cur_dv, qnumel: tl.constexpr, knumel: tl.constexpr,
+                          XBLOCK: tl.constexpr, SIMD_SIZE: tl.constexpr):
+
+    block_idx = tl.program_id(0)
+    block_offset = block_idx * XBLOCK
+    LOOP_COUNT: tl.constexpr = (XBLOCK + SIMD_SIZE - 1) // SIMD_SIZE
+    for loop_index in range(LOOP_COUNT):
+        loop_offset = block_offset + loop_index * SIMD_SIZE + tl.arange(0, SIMD_SIZE)
+        mask0 = loop_offset < qnumel
+        tmp0 = tl.load(dq + loop_offset, mask0).to(tl.float32)
+        tmp1 = tl.load(cur_dq + loop_offset, mask0).to(tl.float32)
+        tmp0 = tmp1 + tmp0
+        tl.store(dq + loop_offset, tmp0.to(tl.bfloat16), mask0)
+
+        mask1 = loop_offset < knumel
+        tmp2 = tl.load(dk + loop_offset, mask1).to(tl.float32)
+        tmp3 = tl.load(cur_dk + loop_offset, mask1).to(tl.float32)
+        tmp2 = tmp2 + tmp3
+
+        tmp4 = tl.load(dv + loop_offset, mask1).to(tl.float32)
+        tmp5 = tl.load(cur_dv + loop_offset, mask1).to(tl.float32)
+        tmp4 = tmp4 + tmp5
+
+        tl.store(dk + loop_offset, tmp2.to(tl.bfloat16), mask1)
+        tl.store(dv + loop_offset, tmp4.to(tl.bfloat16), mask1)
+
+
+def backward_update_triton(dq, dk, dv, cur_dq, cur_dk, cur_dv):
+    # parameters need auto-tuning.
+ SIMD_SIZE = 4*1024 + + xblock = max(dq.numel(), dk.numel()) // NBLOCKS + + tl_fn_backward_update[40,1,1](dq, dk, dv, cur_dq, cur_dk, cur_dv, dq.numel(), dk.numel(), xblock, SIMD_SIZE, debug=True) + diff --git a/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d_la.py b/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d_la.py new file mode 100644 index 000000000..0ec2907c9 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d_la.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl +import torch +import torch_npu + +#number of block per aiv core +NBLOCKS = 32 +# size of sub_block in one block +# reduce the size from 64 to 32 due to UB overflow +S_SUB_SIZE = 32 + +@triton.jit +def tl_fn_forward_update_la( prev_attn_out, prev_softmax_log_max_sum, cur_attn_out, cur_softmax_log_max_sum, + B:tl.constexpr,N:tl.constexpr,S:tl.constexpr,D:tl.constexpr, PREV_ATTN_NSTRIDE:tl.constexpr, PREV_SOFTMAX_NSTRIDE:tl.constexpr, + CUR_ATTN_NSTRIDE:tl.constexpr, CUR_SOFTMAX_NSTRIDE:tl.constexpr, S_NBLOCKS:tl.constexpr, S_SUB:tl.constexpr): + + S_BLOCK:tl.constexpr = (S + S_NBLOCKS - 1) // S_NBLOCKS + S_NSUB:tl.constexpr = (S_BLOCK + S_SUB - 1) // S_SUB + LOOP_COUNT : tl.constexpr = (S_BLOCK * B * N + S_SUB - 1) // S_SUB + block_idx = tl.program_id(0) + + s_block_start = block_idx * S_BLOCK + # assuming S stride is D, if D is not contiguous, need to use 2-d offset + SIMD_SIZE:tl.constexpr = S_SUB * D + + for loop_index in range(LOOP_COUNT): + b = loop_index // (N * S_NSUB) + n = (loop_index // S_NSUB) % N + s_loop_start = (loop_index % S_NSUB) * S_SUB + s = s_block_start + s_loop_start + + sd_offsets = D * s + tl.arange(0, SIMD_SIZE) + s1_offsets = s + tl.arange(0, S_SUB) + + mask0 = None + softmax_offsets = PREV_SOFTMAX_NSTRIDE * (b * N + n) + s1_offsets + prev_softmax_local = tl.load(prev_softmax_log_max_sum + softmax_offsets, mask0) + offsets = CUR_SOFTMAX_NSTRIDE * (b * N + n) + s1_offsets + cur_softmax_local = tl.load(cur_softmax_log_max_sum + offsets, mask0) + + attn_offsets = PREV_ATTN_NSTRIDE * (b * N + n) + sd_offsets + prev_attn_local = tl.load(prev_attn_out + attn_offsets, mask0) + offsets = CUR_ATTN_NSTRIDE * (b * N + n) + sd_offsets + cur_attn_local = tl.load(cur_attn_out + offsets, mask0) + + tmp0 = tl.exp(cur_softmax_local) + tmp1 = tl.exp(prev_softmax_local) + softmax_log_max_sum = tl.log(tmp0 + tmp1) + tmp2 = (prev_softmax_local - softmax_log_max_sum).reshape(S_SUB,1).broadcast_to(S_SUB, D) + tmp3 = (cur_softmax_local- softmax_log_max_sum).reshape(S_SUB,1).broadcast_to(S_SUB,D) + + attn_out = tl.exp(tmp2) * prev_attn_local.reshape(S_SUB,D) + (tl.exp(tmp3) * cur_attn_local.reshape(S_SUB,D)) + mask1 = None + tl.store(prev_softmax_log_max_sum + softmax_offsets, softmax_log_max_sum, mask1) + tl.store(prev_attn_out + attn_offsets, attn_out.reshape(SIMD_SIZE,), mask1) + +#target_name = torch.npu.get_device_name(torch.npu.current_device()) +guards = {"dummy":None} +def forward_update_triton(prev_attn_out, prev_softmax_log_max_sum, cur_attn_out, cur_softmax_log_max_sum): + (B,N,S,D) = cur_attn_out.shape + #shape is (b,n,s,d) + PREV_ATTN_NSTRIDE = prev_attn_out.stride()[1] + PREV_SOFTMAX_NSTRIDE = prev_softmax_log_max_sum.stride()[1] + CUR_ATTN_NSTRIDE = cur_attn_out.stride()[1] + CUR_SOFTMAX_NSTRIDE = cur_softmax_log_max_sum.stride()[1] + + + + device_id = cur_attn_out.device.index + device = "npu:" +str(device_id) + + 
tl_fn_forward_update_la[NBLOCKS,1,1](prev_attn_out, prev_softmax_log_max_sum, cur_attn_out, cur_softmax_log_max_sum, B=B, N=N, S=S, D=D, + PREV_ATTN_NSTRIDE=PREV_ATTN_NSTRIDE, PREV_SOFTMAX_NSTRIDE=PREV_SOFTMAX_NSTRIDE, CUR_ATTN_NSTRIDE=CUR_ATTN_NSTRIDE, + CUR_SOFTMAX_NSTRIDE=CUR_SOFTMAX_NSTRIDE, S_NBLOCKS=NBLOCKS, S_SUB=S_SUB_SIZE, debug=True) + diff --git a/third_party/ascend/examples/pytest_ut/conftest.py b/third_party/ascend/examples/pytest_ut/conftest.py new file mode 100644 index 000000000..5e93b7182 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/conftest.py @@ -0,0 +1,14 @@ +import pytest +import torch + + +@pytest.fixture(scope="session", autouse=True) +def assign_npu(worker_id): + npu_count = torch.npu.device_count() + if worker_id == "master": + npu_id = 0 + else: + idx = int(worker_id.replace("gw", "")) + npu_id = idx % npu_count + torch.npu.set_device(npu_id) + diff --git a/third_party/ascend/examples/pytest_ut/test_2d_permute.py b/third_party/ascend/examples/pytest_ut/test_2d_permute.py new file mode 100644 index 000000000..93f977e11 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_2d_permute.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl + +import torch +import torch_npu + +def fn(x): + return x.t() + +@triton.jit +def triton_2d_permute(output_ptr, input_ptr, X : tl.constexpr, Y : tl.constexpr): + xindex = tl.arange(0, X * Y) + input_local = tl.load(input_ptr + xindex) + output_local = input_local.reshape(X, Y).trans().reshape(X*Y) + tl.store(output_ptr + xindex, output_local) + + +@pytest.mark.parametrize('X', [32, 64, 256]) +@pytest.mark.parametrize('Y', [16, 32]) +def test_cases(X, Y): + + x = torch.randn((X, Y)).npu() + output1 = fn(x) + output2 = torch.randn(output1.shape, dtype=output1.dtype).npu() + + triton_2d_permute[1, 1, 1](output2, x, X, Y, debug=True) + print(output1) + print(output2) + + torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + + + + + + diff --git a/third_party/ascend/examples/pytest_ut/test_3Dgrid.py b/third_party/ascend/examples/pytest_ut/test_3Dgrid.py new file mode 100644 index 000000000..437467573 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_3Dgrid.py @@ -0,0 +1,66 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + +BLOCK: tl.constexpr=32 + +@triton.jit +def triton_(in_ptr0, out_ptr0, x0_numel, r1_numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, + block_id_threshold: tl.constexpr, XBLOCK1: tl.constexpr, num_core: tl.constexpr): + RBLOCK: tl.constexpr = 64 + + block_idx=tl.program_id(0)*tl.num_programs(1)*tl.num_programs(2)+tl.program_id(1)*tl.num_programs(2)+tl.program_id(2) + if (block_idx < block_id_threshold): + offset = block_idx * XBLOCK + loops1 = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB # 32+23 / 24 = 2 + upper = offset + XBLOCK + else: + offset = block_id_threshold * XBLOCK + (block_idx - block_id_threshold) * XBLOCK1 #pid=34 offset = 9*32 + (34-9)*24 = 888 + loops1 = (XBLOCK1 + XBLOCK_SUB - 1) // XBLOCK_SUB #1 + if (block_idx ==num_core -1): + upper = x0_numel + else: + upper = offset + XBLOCK1 # 912 + + base1 = tl.arange(0, XBLOCK_SUB) + base2 = tl.arange(0, RBLOCK) + loops2: tl.constexpr = (r1_numel + RBLOCK - 1) // RBLOCK + for loop1 in range(loops1): + x = offset + (loop1 * XBLOCK_SUB) + base1 + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1[None, :] + x0 = offset + (loop1 * XBLOCK_SUB) 
+ base1[:, None] + xmask = x0 < upper + r1_prime = base2[:, None] + rindex = base2 + r1 = base2[None, :] + rmask = r1 < r1_numel + tmp0 = tl.load(in_ptr0 + (r1 + (64*x0)), rmask & xmask,other=0.0) + + tmp1 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) + tmp2_tmp = tl.sum(tmp1,1) + tmp2 = tmp2_tmp.reshape(XBLOCK_SUB,1) + + tl.store(out_ptr0 + (x0), tmp2, xmask) + + +guards = {"dummy": None} + +# @pytest.mark.skip(reason="multi-process error, to be fixed.") +@pytest.mark.parametrize("size", [(1025, 64)]) +def test_3dgrid(size): + b = torch.randn((size), dtype=torch.float32).npu() + c = torch.sum(b, dim=1) + + ret = torch.randn((size[0]), dtype=torch.float32).npu() + + triton_[5, 2, 4](b, ret, size[0], size[1], XBLOCK=32, XBLOCK_SUB=16, block_id_threshold=9, XBLOCK1=24, num_core=40, debug=True) + print(c[0:8]) + print(ret[0:8]) + torch.testing.assert_close(c, ret) + print("test 3D launch passed") + +if __name__ == "__main__": + pytest.main([__file__]) + diff --git a/third_party/ascend/examples/pytest_ut/test_abs.py b/third_party/ascend/examples/pytest_ut/test_abs.py new file mode 100644 index 000000000..f44b3051e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_abs.py @@ -0,0 +1,39 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +def torch_pointwise(x0): + res = torch.abs(x0) + return res + + +@triton.jit +def triton_abs(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp2 = tl.abs(tmp0) + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['int32', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0) + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + triton_abs[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_abs_2.py b/third_party/ascend/examples/pytest_ut/test_abs_2.py new file mode 100644 index 000000000..cfd142318 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_abs_2.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + +def torch_abs(x0): + res = torch.abs(x0) + return res + +@triton.jit +def triton_abs(in_ptr0, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.abs(tmp0) + tl.store(out_ptr0 + (x0), tmp1, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float16', (4, 4), 4, 4, 4], + ['float32', (4, 4), 4, 4, 4], + ]) +def test_abs(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype) + y_ref = torch_abs(x0) + tyname = test_common.get_triton_sig_typename(dtype) + + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + x0 = x0.npu() + triton_abs[ncore, 1, 1](x0, y_cal, xblock, xblock_sub, debug=True) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_add.py b/third_party/ascend/examples/pytest_ut/test_add.py new file mode 100644 index 000000000..3cca181f9 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_add.py @@ -0,0 +1,59 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + + +def torch_pointwise(x0, x1): + res = x0 + x1 + return res + + +@triton.jit +def triton_add(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 + tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_all_blocks_parallel(param_list, monkeypatch): + monkeypatch.setenv("TRITON_ALL_BLOCKS_PARALLEL", "1") + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) + monkeypatch.delenv("TRITON_ALL_BLOCKS_PARALLEL") diff --git a/third_party/ascend/examples/pytest_ut/test_advance.py b/third_party/ascend/examples/pytest_ut/test_advance.py new file mode 100644 index 000000000..744eb73ae --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_advance.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest + +@triton.jit +def fn_npu_(output_ptr, x_ptr,y_ptr,z_ptr,output_ptr1,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + # idx = tl.arange(0,XB*YB*ZB) + block_ptr_in=tl.make_block_ptr( + base = x_ptr, + shape = (XB,YB,ZB), + strides = (YB*ZB,ZB,1), + offsets = (9,6,5), + block_shape = (XB,YB,ZB), + order = (2,1,0), + ) + bbptr = tl.advance(block_ptr_in,(-9,-6,-5)) + # XB,YB,1 + X = tl.load(bbptr) + # X = tl.load(x_ptr + idx) + # Y = tl.load(y_ptr + idx) + + # xx=tl.view(X,(ZB*YB,XB)) + + oidx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + block_ptr_out=tl.make_block_ptr( + base = output_ptr, + shape = (XB,YB,ZB), + strides = (YB*ZB,ZB,1), + offsets = (0,0,0), + block_shape = (XB,YB,ZB), + order = (2,1,0), + ) + tl.store(block_ptr_out,X) + # tl.store(output_ptr + tl.arange(0,ZB*YB)[:,None]*XB+xidx[None,:], xx) + # tl.store(output_ptr + xidx[:,None]*YB+yidx[None,:], yy) + +@triton.jit +def fn_npu_2d(output_ptr, x_ptr, y_ptr, z_ptr, output_ptr1, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + xoffset = tl.program_id(0) + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB), + strides=(YB, 1), + offsets=(6 + xoffset, 5), + block_shape=(XB, YB), + order=(1, 0), + ) + bbptr = tl.advance(block_ptr_in, (-6, -5)) + # XB,YB,1 + X = tl.load(bbptr) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB), + strides=(YB, 1), + offsets=(xoffset, 0), + block_shape=(XB, YB), + order=(1, 0), + ) + tl.store(block_ptr_out, X) + +@triton.jit +def fn_npu_3d(output_ptr, x_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr): + block_ptr_in = tl.make_block_ptr( + base=x_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, 2), + order=(2, 1, 0), + ) + + block_ptr_out = tl.make_block_ptr( + base=output_ptr, + shape=(XB, YB, ZB), + strides=(YB * ZB, ZB, 1), + offsets=(0, 0, 0), + block_shape=(XB, YB, 2), + order=(2, 1, 0), + ) + + for _ in range(ZB // 2): + X = tl.load(block_ptr_in, boundary_check=(0, 1, 2)) + tl.store(block_ptr_out, X, boundary_check=(0, 1, 2)) + block_ptr_in = tl.advance(block_ptr_in, (0, 0, 2)) + block_ptr_out = tl.advance(block_ptr_out, (0, 0, 2)) + + +@pytest.mark.parametrize('dtype', ["int32", "float32", "int16"]) +@pytest.mark.parametrize('shape', [(33, 9, 6), (8, 8, 4)]) +def test_advance_with_boundary_check(dtype, shape): + x = torch.randint(low=-128, high=128, size=shape, dtype=eval('torch.' + dtype)).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' 
+ dtype)).npu() + a = x + + fn_npu_3d[1, 1, 1](output, x, XB=shape[0], YB=shape[1], ZB=shape[2]) + + torch.testing.assert_close(output, a) + + +@pytest.mark.parametrize('dtype', ["int32", "float32", "int16"]) +@pytest.mark.parametrize('shape', [(1, 3), (3, 1), (1, 13), (13, 1)]) +def test_advance_supplement(dtype, shape): + x = torch.randint(low=-128,high=128,size=shape,dtype=eval('torch.' + dtype)).npu() + y = torch.randint(low=-128,high=128,size=shape,dtype=eval('torch.' + dtype)).npu() + z = torch.randint(low=-128,high=128,size=shape,dtype=eval('torch.' + dtype)).npu() + + output = torch.randint(1, shape, dtype=eval('torch.' + dtype)).npu() + output1 = output + + a = x + + fn_npu_2d[1,1,1](output, x, y, z, output1, XB=shape[0], YB=shape[1], ZB=1) + + torch.testing.assert_close(output, a) + + +paras = [ + ('*fp32',eval('torch.float32'),2,256,16), + ('*fp32',eval('torch.float32'),8,8,4), + ('*fp16',eval('torch.float16'),2,256,16), + ('*fp16',eval('torch.float16'),8,8,4), + ('*i8',eval('torch.int8'),2,256,16), + ('*i8',eval('torch.int8'),8,8,4), +] + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', paras) +def test_npu(para_type,data_type,XB,YB,ZB): + + x = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + y = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + z = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + + print(f"shape = {x.shape}") + print(x.dtype) + + output = torch.randint(1, (XB,YB,ZB), dtype=data_type).npu() + output1 = output + print(f"output.dtype={output.dtype}") + + a = x + print(a) + fn_npu_[1,1,1](output,x,y,z,output1, XB=XB, YB=YB, ZB=ZB, debug=True) + print(output) + torch.testing.assert_close(output,a) + +if __name__=="__main__": + pytest.main([__file__]) diff --git a/third_party/ascend/examples/pytest_ut/test_and.py b/third_party/ascend/examples/pytest_ut/test_and.py new file mode 100644 index 000000000..418bc1942 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_and.py @@ -0,0 +1,42 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_and(x0, x1): + res = x0 & x1 + return res + + +@triton.jit +def triton_and(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index) + tmp1 = tl.load(in_ptr1 + x_index) + tmp2 = tmp0 & tmp1 + tl.store(out_ptr0 + x_index, tmp2) + + +@pytest.mark.parametrize('param_list', + [ + ['int32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_and(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_and(x0, x1) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_and[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_arange.py b/third_party/ascend/examples/pytest_ut/test_arange.py new file mode 100644 index 000000000..e105adfe5 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_arange.py @@ -0,0 +1,114 @@ +import math +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import test_common + + +def torch_arange(start, end): + TRITON_MAX_TENSOR_NUMEL = 1048576 + if end < start: + raise ValueError("arange's end argument must be greater than the start argument") + if end - start > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}") + return torch.arange(start, end) + + +def torch_arange_access(start, end): + z = torch.zeros([end], dtype=torch.int32).npu() + v = torch.arange(start, end).npu() + z[start:end] = v + return z + + +@triton.jit +def triton_arange(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + +@triton.jit +def triton_arange_access(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(START, END) + val = tl.arange(START, END) + tl.store(z + off, val) + + +@pytest.mark.parametrize('param_list', + [ + [0, 128], + [7, 128], + [128, 1024], + ] + ) +def test_case(param_list): + start, end = param_list + shape = [end - start] + block = end - start + dtype = 'int32' + + y_ref = torch_arange(start, end) + y_cal = torch.zeros(shape, dtype=torch.int32).npu() + + triton_arange[(1, )](y_cal, START = start, END = end, BLOCK = block) + + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('param_list', + [ + [0, 128], + [7, 128], + [128, 1024], + ] + ) +def test_case_access(param_list): + start, end = param_list + shape = [end] + block = end - start + dtype = 'int32' + + y_ref = torch_arange_access(start, end) + y_cal = torch.zeros(shape, dtype=torch.int32).npu() + + triton_arange_access[(1, )](y_cal, START = start, END = end, BLOCK = block) + + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('invalid_param_list', + [ + [0, 10000000], + ] + ) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, + "end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 1048576") +def test_arange_invalid_range(invalid_param_list): + start, end = invalid_param_list + shape = [end - start] + block = end - start + + y_cal = torch.zeros(shape, dtype=torch.int32).npu() + + triton_arange[(1, )](y_cal, START = start, END = end, BLOCK = block) + + +@pytest.mark.parametrize('invalid_param_list', + [ + [1024, 128], + ] + ) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, + "arange's end argument must be greater than the start argument") +def test_arange_invalid_revinput(invalid_param_list): + start, end = invalid_param_list + range = abs(end - start) + shape = [range] + block = range + + y_cal = torch.zeros(shape, dtype=torch.int32).npu() + + triton_arange[(1, )](y_cal, START = start, END = end, BLOCK = block) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_associative_scan.py b/third_party/ascend/examples/pytest_ut/test_associative_scan.py new file mode 100644 index 000000000..8d60b53e8 --- /dev/null +++ 
b/third_party/ascend/examples/pytest_ut/test_associative_scan.py @@ -0,0 +1,174 @@ +import math +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +import test_common + + +def torch_func(x, dim, reverse): + if reverse: + x = torch.flip(x, [dim]) + res = torch.cumsum(x, dim=dim) + return res + + +def combine_fn_test_torch(a, b, combine_fn): + return torch.maximum(a, b) + + +def torch_func_scan(x: torch.Tensor, dim: int, combine_fn='maximum', reverse=False): + """ + PyTorch implements associative_scan, with semantics fully aligned with Triton. + """ + dim = dim % x.ndim + + if reverse: + x = x.flip(dim) + + N = x.size(dim) + tensors = torch.unbind(x, dim=dim) + + outputs = [] + carry = tensors[0] + outputs.append(carry) + + for i in range(1, N): + carry = combine_fn_test_torch(tensors[i], carry, combine_fn) + outputs.append(carry) + + output = torch.stack(outputs, dim=dim) + + if reverse: + output = output.flip(dim) + + return output + + +@triton.jit +def combine_fn_test(a, b): + return tl.maximum(a, b) + + +@triton.jit +def triton_kernel_1d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + XBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + idx = tl.arange(0, XBLOCK) + x = tl.load(in_ptr0 + idx) + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=combine_fn_test) + tl.store(out_ptr0 + idx, ret) + + +@triton.jit +def triton_kernel_2d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx = idx_x[:, None] * numel_r + idx_r[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=combine_fn_test) + tl.store(out_ptr0 + idx, ret) + + +@triton.jit +def triton_kernel_3d_scan( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + numel_z: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, + ZBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + tl.static_assert( + numel_z == ZBLOCK, "numel_z must be equal to ZBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx_z = tl.arange(0, ZBLOCK) + idx = idx_x[:, None, None] * numel_r * numel_z + idx_r[None, :, None] * numel_z + idx_z[None, None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.associative_scan(x, axis=dim, reverse=reverse, combine_fn=combine_fn_test) + tl.store(out_ptr0 + idx, ret) + + +def triton_func_scan(x, dim, reverse): + res = torch.empty_like(x) + print(f"res.dtype = {res.dtype}") + shape = x.size() + if len(shape) == 1: + if dim >= 1: + pytest.skip("dim >= 1 for 1D tensor, skipping.") + triton_kernel_1d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[0] + ) + elif len(shape) == 2: + if dim >= 2: + pytest.skip("dim >= 2 for 2D tensor, skipping.") + triton_kernel_2d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1] + ) + 
elif len(shape) == 3: + if dim >= 3: + pytest.skip("dim >= 3 for 3D tensor, skipping.") + triton_kernel_3d_scan[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[2], x.shape[0], x.shape[1], x.shape[2] + ) + else: + pytest.skip(f"This testcase unsupported tensor dimension: {len(shape)}") + + return res + + +@pytest.mark.parametrize("dtype", ['int32', 'float32']) +@pytest.mark.parametrize("shape", [(128,), (8, 4), (128, 4, 16)]) +@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("combine_fn", ['maximum', ]) +@pytest.mark.parametrize("reverse", [False]) +def test_scan(dtype, shape, dim, combine_fn, reverse): + torch.manual_seed(0) + x = test_common.generate_tensor(shape=shape, dtype=dtype) + x_gold = x + cpu_res = torch_func_scan(x_gold, dim, combine_fn, reverse) + print(f"cpu_res: {cpu_res}") + + x_npu = x.npu() + triton_res = triton_func_scan(x_npu, dim, reverse) + print(f"triton_res: {triton_res}") + + test_common.validate_cmp(dtype, triton_res, cpu_res) + print(f"Validate PASS") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_atan.py b/third_party/ascend/examples/pytest_ut/test_atan.py new file mode 100644 index 000000000..70ae23afa --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atan.py @@ -0,0 +1,80 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +import triton.language.extra.ascend.libdevice as libdevice + +def standard_unary(x0, dtype): + res = torch.atan(x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = libdevice.atan(x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_atan2.py b/third_party/ascend/examples/pytest_ut/test_atan2.py new file mode 100644 index 000000000..dda0f9638 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atan2.py @@ -0,0 +1,108 @@ +import pytest + +import 
triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +from triton.language import math + + +def standard_unary(x0, y0, dtype): + res = torch.atan2(y0, x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = math.atan2(y, x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_atan2_common(dtype, sigtype, N, NUMEL): + if N == 0: + return + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + y0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, y0, dtype) + x0 = x0.npu() + y0 = y0.npu() + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, y0, out, N=N, NUMEL=NUMEL, debug=True) + + test_common.validate_cmp(sigtype, out, ans) + +input_vals = [ + (0.0, 1.0), + (0.0, -1.0), +] + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes[0:2]) +@pytest.mark.parametrize('X,Y', input_vals) +def test_atan2_special(dtype, sigtype, N, NUMEL, X, Y): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = torch.full((N,), X, dtype=eval(f'torch.{sigtype}')) + y0 = torch.full((N,), Y, dtype=eval(f'torch.{sigtype}')) + + ans = standard_unary(x0, y0, dtype) + x0 = x0.npu() + y0 = y0.npu() + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, y0, out, N=N, NUMEL=NUMEL, debug=True) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_add.py b/third_party/ascend/examples/pytest_ut/test_atomic_add.py new file mode 100644 index 000000000..27505f847 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atomic_add.py @@ -0,0 +1,135 @@ +import triton +import triton.language as tl +import pytest +import test_common +import torch +import torch_npu + +@triton.jit +def atomic_add(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.atomic_add(out_ptr0 + (x1), tmp0, xmask) + tl.store(out_ptr1 + (x1), tmp1, xmask) + + 
+@triton.jit +def atomic_add_supply( + in_ptr0, out_ptr0, n_elements, BLOCK_SIZE: tl.constexpr +): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.atomic_add(out_ptr0 + (x1), tmp0, xmask) +@pytest.mark.parametrize('param_list', + [ + ['int16', (32, 32), 2], + ['int8', (32, 32), 2], + ['float32', (32, 32), 2], + ['float16', (64, 64), 4], + ['float32', (128, 128), 8], + ['float16', (128, 128), 16], + ['float32', (32768, 16), 32], + ] + ) +def test_atomic_add(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] / ncore + split_size = shape[0] // ncore + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype = eval(f'torch.{dtype}')).npu() + x1 = torch.full((split_size, shape[1]), 2, dtype = eval(f'torch.{dtype}')).npu() + y = torch.full((split_size, shape[1]), -10, dtype = eval(f'torch.{dtype}')).npu() + + y_ref = x1 + 0 + x1_ref = x1 + ncore * x0_value + + n_elements = shape[0] * shape[1] + atomic_add[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) + +@pytest.mark.parametrize('invalid_param_list', + [ + ['int64', (32, 32), 2], + ] + ) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "not support int64") +def test_atomic_add_invalid(invalid_param_list): + dtype, shape, ncore = invalid_param_list + block_size = shape[0] * shape[1] / ncore + split_size = shape[0] // ncore + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype = eval(f'torch.{dtype}')).npu() + x1 = torch.full((split_size, shape[1]), 2, dtype = eval(f'torch.{dtype}')).npu() + y = torch.full((split_size, shape[1]), -10, dtype = eval(f'torch.{dtype}')).npu() + y_ref = x1 + 0 + x1_ref = x1 + ncore * x0_value + n_elements = shape[0] * shape[1] + atomic_add[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) + +@triton.jit +def atomic_add_2d(in_ptr0, out_ptr0, out_ptr1, numel_0, numel_1, BLOCK_SIZE_0 : tl.constexpr, BLOCK_SIZE_1 : tl.constexpr): + pid = tl.program_id(0) + idx0_in = pid * BLOCK_SIZE_0 + tl.arange(0, BLOCK_SIZE_0)[:, None] + idx0_out = tl.arange(0, BLOCK_SIZE_0)[:, None] + idx1 = tl.arange(0, BLOCK_SIZE_1)[None, :] + idx_in = idx0_in * BLOCK_SIZE_1 + idx1 + idx_out = idx0_out * BLOCK_SIZE_1 + idx1 + msk_in = (idx0_in < numel_0) & (idx1 < numel_1) + msk_out = (idx0_out < numel_0) & (idx1 < numel_1) + tmp0 = tl.load(in_ptr0 + idx_in, msk_in) + tmp1 = tl.atomic_add(out_ptr0 + idx_out, tmp0, msk_out) + tl.store(out_ptr1 + idx_out, tmp1, msk_out) + +@pytest.mark.parametrize('param_list', + [ + ['float32', (32, 32), 2], + ] + ) +def test_atomic_add_2d(param_list): + dtype, shape, ncore = param_list + split_size = shape[0] // ncore + block_size_0 = split_size + block_size_1 = shape[1] + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype = eval('torch.float32')).npu() + x1 = torch.full((split_size, shape[1]), 2, dtype = eval('torch.float32')).npu() + y = torch.full((split_size, shape[1]), -10, dtype = eval('torch.float32')).npu() + + y_ref = x1 + 0 + x1_ref = x1 + ncore * x0_value + + atomic_add_2d[ncore, 1, 1](x0, x1, y, shape[0], shape[1], BLOCK_SIZE_0=block_size_0, BLOCK_SIZE_1=block_size_1) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@pytest.mark.parametrize('shape', [(3, 1), (13, 1), (32, 1), (256, 1)]) 
+@pytest.mark.parametrize('dtype', ['float32']) +def test_atomic_add_2d_supply(dtype, shape): + ncore = 1 + block_size = shape[0] * shape[1] / ncore + split_size = shape[0] // ncore + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval('torch.' + dtype)).npu() + x1 = torch.full((split_size, shape[1]), 2, dtype=eval('torch.' + dtype)).npu() + + y_ref = x1 + 0 + x1_ref = x1 + ncore * x0_value + + n_elements = shape[0] * shape[1] + atomic_add_supply[shape[0], 1, 1](x0, x1, n_elements, BLOCK_SIZE=shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) + +if __name__ == "__main__": + param_list = ['float32', (32, 32), 2] + test_atomic_add_2d(param_list) diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_and.py b/third_party/ascend/examples/pytest_ut/test_atomic_and.py new file mode 100644 index 000000000..97518438d --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atomic_and.py @@ -0,0 +1,48 @@ +import triton +import triton.language as tl +import pytest +import test_common +import torch +import torch_npu + + +@triton.jit +def atomic_and(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.atomic_and(out_ptr0 + (x1), tmp0, xmask) + tl.store(out_ptr1 + (x1), tmp1, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['int32', (32, 32), 2], + ['int16', (32, 32), 2], + ['int8', (16, 16), 4], + ] + ) +def test_atomic_and(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] // ncore + split_size = shape[0] // ncore + + val = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).npu() + + pointer = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu() + pointer_old = torch.full_like(pointer, -10).npu() + pointer_ref = pointer.clone() + + for i in range(ncore - 1): + pointer_ref &= val[(i * split_size):((i + 1) * split_size)] + + pointer_ref_last = pointer_ref.clone() + pointer_ref &= val[((ncore - 1) * split_size):(ncore * split_size)] + + n_elements = shape[0] * shape[1] + atomic_and[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_cas.py b/third_party/ascend/examples/pytest_ut/test_atomic_cas.py new file mode 100644 index 000000000..97415570e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atomic_cas.py @@ -0,0 +1,59 @@ +import triton +import triton.language as tl +import pytest +import test_common +import torch +import torch_npu + + +@triton.jit +def atomic_cas(in_ptr0, in_ptr1, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + val = tl.load(in_ptr0 + (x0), xmask) + cmp = tl.load(in_ptr1 + (x0), xmask) + tmp1 = tl.atomic_cas(out_ptr0 + (x1), cmp, val) + tl.store(out_ptr1 + (x1), tmp1, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['int16', (8, 8), 2], + ['int32', (32, 32), 6], + ['int64', (32, 32), 2], + ['float32', (32, 32), 2], + ['float16', (64, 64), 4], + ['float32', (128, 128), 8], + ['float16', (128, 128), 16], + ] + ) +def 
test_atomic_cas(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] // ncore + split_size = shape[0] // ncore + + import random + cmp_val = [random.randint(0, 10) for _ in range(ncore)] + + cmp = torch.ones(split_size, shape[1], dtype=eval(f'torch.{dtype}')).to().npu() * cmp_val[0] + for i in range(1, ncore): + append = torch.ones(split_size, shape[1], dtype=eval(f'torch.{dtype}')).to().npu() * cmp_val[i] + cmp = torch.cat([cmp, append], dim=0) + + val = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).npu() + + pointer = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu() + pointer_old = torch.full_like(pointer, -10).npu() + pointer_ref = pointer.clone() + + for i in range(ncore): + val_subview = val[(i * split_size):((i + 1) * split_size)] + pointer_ref = torch.where(pointer_ref == cmp_val[i], val_subview, pointer_ref) + + n_elements = shape[0] * shape[1] + atomic_cas[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_max.py b/third_party/ascend/examples/pytest_ut/test_atomic_max.py new file mode 100644 index 000000000..76b447887 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atomic_max.py @@ -0,0 +1,114 @@ +import triton +import triton.language as tl +import pytest +import test_common +import torch +import torch_npu +import numpy as np + + +@triton.jit +def triton_test_fn_atomic_max_dma( + in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr +): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + tmp0 = tl.load(in_ptr0 + (x0)) + # only set mask of atomic_max + tl.atomic_max(out_ptr0 + (x1), tmp0, xmask) + + +@triton.jit +def triton_test_fn_atomic_max_dma_supply( + in_ptr0, out_ptr0, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr +): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.atomic_max(out_ptr0 + (x1), tmp0, xmask) +# torch.max do not support int +@pytest.mark.parametrize('param_list', + [ + ['int16', (32, 32), 2], + ['float16', (32, 32), 2], + ['float32', (32, 32), 2], + ['float32', (128, 128), 8], + ['float32', (32768, 16), 32], + ['int32', (32, 32), 2], + ['int32', (128, 128), 8], + ['int32', (32768, 16), 32], + ] + ) +def test_atomic_max(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] / ncore + split_size = shape[0] // ncore + # old size: (32768, 256) + # tensor of (1024, 256) is too big, and it will lead to failure in the backend + # so here we make it smaller + x0 = test_common.generate_tensor(shape, dtype) + x1 = test_common.generate_tensor((split_size, shape[1]), dtype) + y = test_common.generate_tensor((split_size, shape[1]), dtype) + + merged_tensor = torch.cat((x0, x1), dim=0) + chunks = torch.stack(torch.chunk(merged_tensor, ncore+1, dim=0)) + x1_ref = torch.max(chunks, dim=0)[0] + x0 = x0.npu() + x1 = x1.npu() + y = y.npu() + + n_elements = shape[0] * shape[1] + triton_test_fn_atomic_max_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) + + 
+@pytest.mark.parametrize('invalid_param_list', + [ + ['int64', (32, 32), 2], + ] + ) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "not support int64") +def test_atomic_max_invalid(invalid_param_list): + dtype, shape, ncore = invalid_param_list + block_size = shape[0] * shape[1] / ncore + split_size = shape[0] // ncore + x0 = test_common.generate_tensor(shape, dtype) + x1 = test_common.generate_tensor((split_size, shape[1]), dtype) + y = test_common.generate_tensor((split_size, shape[1]), dtype) + + merged_tensor = torch.cat((x0, x1), dim=0) + chunks = torch.stack(torch.chunk(merged_tensor, ncore+1, dim=0)) + x1_ref = torch.max(chunks, dim=0)[0] + x0 = x0.npu() + x1 = x1.npu() + y = y.npu() + + n_elements = shape[0] * shape[1] + triton_test_fn_atomic_max_dma[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) + + +@pytest.mark.parametrize('shape', [(3, 1), (13, 1), (32, 1), (256, 1)]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_atomic_max_2d_supply(dtype, shape): + # old size: (32768, 256) + # tensor of (1024, 256) is too big, and it will lead to failure in the backend + # so here we make it smaller + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + + x1_ref = torch.maximum(x0, x1) + + n_elements = shape[0] * shape[1] + triton_test_fn_atomic_max_dma_supply[shape[0], 1, 1](x0, x1, n_elements, BLOCK_SIZE=shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) +# if __name__ == "__main__": +# test_atomic_max(['int32', (8, 8), 2]) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_or.py b/third_party/ascend/examples/pytest_ut/test_atomic_or.py new file mode 100644 index 000000000..53f5e4ca0 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_atomic_or.py @@ -0,0 +1,48 @@ +import triton +import triton.language as tl +import pytest +import test_common +import torch +import torch_npu + + +@triton.jit +def atomic_or(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + yindex = tl.arange(0, BLOCK_SIZE)[:] + xmask = xindex < n_elements + x0 = xindex + x1 = yindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.atomic_or(out_ptr0 + (x1), tmp0, xmask) + tl.store(out_ptr1 + (x1), tmp1, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['int32', (32, 32), 2], + ['int16', (32, 32), 2], + ['int8', (16, 16), 4], + ] + ) +def test_atomic_or(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] // ncore + split_size = shape[0] // ncore + + val = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).npu() + + pointer = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu() + pointer_old = torch.full_like(pointer, -10).npu() + pointer_ref = pointer.clone() + + for i in range(ncore - 1): + pointer_ref |= val[(i * split_size):((i + 1) * split_size)] + + pointer_ref_last = pointer_ref.clone() + pointer_ref |= val[((ncore - 1) * split_size):(ncore * split_size)] + + n_elements = shape[0] * shape[1] + atomic_or[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_xchg.py b/third_party/ascend/examples/pytest_ut/test_atomic_xchg.py new file mode 100644 
index 000000000..e761d071d
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_atomic_xchg.py
@@ -0,0 +1,52 @@
+import triton
+import triton.language as tl
+import pytest
+import test_common
+import torch
+import torch_npu
+
+
+@triton.jit
+def atomic_xchg(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
+    xoffset = tl.program_id(0) * BLOCK_SIZE
+    xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
+    yindex = tl.arange(0, BLOCK_SIZE)[:]
+    xmask = xindex < n_elements
+    x0 = xindex
+    x1 = yindex
+    tmp0 = tl.load(in_ptr0 + (x0), xmask)
+    tmp1 = tl.atomic_xchg(out_ptr0 + (x1), tmp0, xmask)
+    tl.store(out_ptr1 + (x0), tmp1, xmask)
+
+
+@pytest.mark.parametrize('param_list',
+                         [
+                             ['int32', (32, 32), 2],
+                             ['int16', (16, 16), 2],
+                             ['int8', (32, 32), 2],
+                             ['float32', (16, 16), 2],
+                             ['float16', (64, 64), 4],
+                             ['float32', (128, 128), 8],
+                             ['float16', (128, 128), 16],
+                             ['float32', (32768, 16), 32],
+                         ]
+                         )
+def test_atomic_xchg(param_list):
+    dtype, shape, ncore = param_list
+    block_size = shape[0] * shape[1] // ncore
+    split_size = shape[0] // ncore
+
+    val = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).npu()
+
+    pointer = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu()
+    pointer_ref = pointer.clone()
+    pointer_old = torch.full_like(val, -10).npu()
+    pointer_old_ref = pointer_old.clone()
+
+    pointer_ref = val[((ncore - 1) * split_size):(ncore * split_size)].clone()
+    pointer_old_ref[0:split_size] = pointer
+    pointer_old_ref[split_size:((ncore - 1) * split_size)] = val[0:(ncore - 2) * split_size]
+
+    n_elements = shape[0] * shape[1]
+    atomic_xchg[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1])
+    test_common.validate_cmp(dtype, pointer, pointer_ref)
diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_xor.py b/third_party/ascend/examples/pytest_ut/test_atomic_xor.py
new file mode 100644
index 000000000..4a49b2eb7
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_atomic_xor.py
@@ -0,0 +1,54 @@
+import triton
+import triton.language as tl
+import pytest
+import test_common
+import torch
+import torch_npu
+
+
+@triton.jit
+def atomic_xor(in_ptr0, out_ptr0, out_ptr1, n_elements, BLOCK_SIZE: tl.constexpr):
+    xoffset = tl.program_id(0) * BLOCK_SIZE
+    xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:]
+    yindex = tl.arange(0, BLOCK_SIZE)[:]
+    xmask = xindex < n_elements
+    x0 = xindex
+    x1 = yindex
+    tmp0 = tl.load(in_ptr0 + (x0), xmask)
+    tmp1 = tl.atomic_xor(out_ptr0 + (x1), tmp0, xmask)
+    tl.store(out_ptr1 + (x1), tmp1, xmask)
+
+
+@pytest.mark.parametrize('param_list',
+                         [
+                             ['int32', (32, 32), 2],
+                             ['int16', (32, 32), 7],
+                             ['int8', (32, 32), 10],
+                         ]
+                         )
+def test_atomic_xor(param_list):
+    dtype, shape, ncore = param_list
+    block_size = shape[0] * shape[1] // ncore
+    split_size = shape[0] // ncore
+
+    # Initialize the input values val, all set to 0b0011 (decimal 3)
+    val_value = 3
+    val = torch.full(shape, val_value, dtype=eval(f'torch.{dtype}')).npu()
+
+    # Destination values x1 seen by every core, all set to 0b0101 (decimal 5)
+    pointer_value = 5
+    pointer = torch.full((split_size, shape[1]), pointer_value, dtype=eval(f'torch.{dtype}')).npu()
+    pointer_old = torch.full_like(pointer, -10)
+
+    # After the atomic XOR, pointer ^= val has been applied once by each core.
+    # Since XOR is associative and commutative, the reference value is
+    # pointer_value ^ val_value ^ val_value ^ ... (ncore times)
+    pointer_result = pointer_value
+    for _ in range(ncore):
+        pointer_result ^= val_value
+
+    pointer_ref = torch.full_like(pointer, pointer_result)
+
+    n_elements = shape[0] * shape[1]
+
atomic_xor[ncore, 1, 1](val, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, pointer, pointer_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_attn_cp.py b/third_party/ascend/examples/pytest_ut/test_attn_cp.py new file mode 100644 index 000000000..a614b30ad --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_attn_cp.py @@ -0,0 +1,482 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + + +import time +import torch +import torch_npu +import pytest +import sys + +sys.path.append("../../../") + +from attn_cp_triton_kernel_3d import forward_update_triton, backward_update_triton +from attn_cp_triton_kernel_3d_la import forward_update_triton as forward_update_triton_la + + +def collect_time(model, example_inputs, times: int = 1): + stream = torch.npu.current_stream() + warmup = 1 + stream.synchronize() + for i in range(times + warmup): + out = model(*example_inputs) + if i < warmup: + stream.synchronize() + t0 = time.perf_counter() + else: + t1 = time.perf_counter() + stream.synchronize() + t1 = time.perf_counter() + # GC the result after timing + assert out is not None + return (t1 - t0) / times + + +def print_performance(fn, args=(), times=10, repeat=10, baseline=1.0): + stream = torch.npu.current_stream() + + start = time.perf_counter() + + stream.synchronize() + + for _ in range(repeat * times): + fn(*args) + + stream.synchronize() + + end = time.perf_counter() + took = (end - start) / (times * repeat) + print(f"{took:.6f}") + return took + + +def profile_test(fn, fn_triton, args=(), name="gen_fn", times=10, repeat=10): + print(f"--------------------profiling {name} for {times * repeat} times--------------------") + stream = torch.npu.current_stream() + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, ) + prof = torch_npu.profiler.profile( + activities=[ + # torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU], + record_shapes=False, + profile_memory=False, + with_stack=False, + schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=100, repeat=1, skip_first=10), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result_dir"), + experimental_config=experimental_config) + + stream.synchronize() + prof.start() + + for _ in range(times * repeat): + fn_triton(*args) + prof.step() + + prof.stop() + + +def benchmark_test(fn, fn_triton, args=(), name="gen_fn", times=10, repeat=10, profile=False): + print(f"--------------------benchmark_{name} for {times * repeat} times--------------------") + stream = torch.npu.current_stream() + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, ) + prof = torch_npu.profiler.profile( + activities=[ + # torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU], + record_shapes=False, + profile_memory=False, + with_stack=False, + schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=100, repeat=1, skip_first=10), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result_dir"), + experimental_config=experimental_config) + + stream.synchronize() + prof.start() + start = time.perf_counter() + for _ in range(times * repeat): + fn_triton(*args) + if profile: + prof.step() + 
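+ # Drain the NPU stream before reading the end timestamp so that asynchronously
+ # launched Triton kernels are included in the measured time.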
stream.synchronize()
+ end = time.perf_counter()
+ time_compiled = (end - start) / (times * repeat)
+ print(f"{time_compiled:.6f}")
+
+ print(f"Running eager {name} for {times * repeat} times")
+ start = time.perf_counter()
+ for _ in range(times * repeat):
+ fn(*args)
+ if profile:
+ prof.step()
+ stream.synchronize()
+ end = time.perf_counter()
+ time_eager = (end - start) / (times * repeat)
+ print(f"{time_eager:.6f}")
+
+ time_eager *= 1000000
+ time_compiled *= 1000000
+ print(
+ f"Accelerated: {(time_eager - time_compiled) / time_compiled * 100:.4f}% eager takes {time_eager:.3f} us, triton takes {time_compiled:.3f} us")
+
+ return time_eager, time_compiled
+
+
+from einops import rearrange
+
+
+def trans_BNSD2SBH(x):
+ """Transpose the data layout from BNSD to SBH"""
+ return rearrange(x, 'b n s d -> s b (n d)').contiguous()
+
+
+def broadcast_and_trans_BNSD2SBH(x, h):
+ """Broadcast and transpose a tensor from [b, n, s, 8] to [s, b, h]"""
+ n = x.shape[1]
+ d = h // n
+ # [b, n, s, 8] -> [b, n, s, d]
+ new_x = x[..., 0].unsqueeze(3)
+ new_x = new_x.repeat(1, 1, 1, d)
+ return trans_BNSD2SBH(new_x)
+
+
+def forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max, cur_softmax_sum):
+ # update softmax_max
+ org_dtype = prev_attn_out.dtype
+ softmax_max = torch.maximum(prev_softmax_max, cur_softmax_max)
+ prev_scale = torch.exp(prev_softmax_max - softmax_max)
+ cur_scale = torch.exp(cur_softmax_max - softmax_max)
+ # update softmax_sum
+ prev_softmax_sum_scaled = prev_softmax_sum * prev_scale
+ cur_softmax_sum_scaled = cur_softmax_sum * cur_scale
+ softmax_sum = prev_softmax_sum_scaled + cur_softmax_sum_scaled
+ # out updating scale
+ prev_out_scale = prev_softmax_sum_scaled / softmax_sum
+ cur_out_scale = cur_softmax_sum_scaled / softmax_sum
+ # [b, n, s, 8] -> [s, b, h]
+ prev_out_scale_sbh = broadcast_and_trans_BNSD2SBH(prev_out_scale, prev_attn_out.shape[-1])
+ cur_out_scale_sbh = broadcast_and_trans_BNSD2SBH(cur_out_scale, prev_attn_out.shape[-1])
+ # update output
+ attn_out = prev_attn_out * prev_out_scale_sbh + cur_attn_out * cur_out_scale_sbh
+ attn_out = attn_out.to(org_dtype)
+ return attn_out, softmax_max, softmax_sum
+
+
+def prove_forward_update():
+ def data_validation(prev_softmax_max, cur_softmax_max, prev_softmax_sum, cur_softmax_sum, prev_attn_out,
+ cur_attn_out):
+
+ (tt_attn_out, tt_softmax_max, tt_softmax_sum) = forward_update_triton(prev_attn_out,
+ prev_softmax_max, prev_softmax_sum,
+ cur_attn_out, cur_softmax_max,
+ cur_softmax_sum)
+
+ (attn_out, softmax_max, softmax_sum) = forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
+ cur_attn_out,
+ cur_softmax_max, cur_softmax_sum)
+
+ try:
+ assert torch.equal(softmax_max, tt_softmax_max)
+ print("max comparison passed.")
+ assert torch.equal(softmax_sum, tt_softmax_sum)
+ print("sum comparison passed.")
+ torch.testing.assert_close(attn_out, tt_attn_out)
+ print("attn comparison passed.")
+
+ except Exception as e:
+ print(e)
+ print("comparison not passed")
+ print(
+ f"proving finished, attn shape:{prev_attn_out.shape}, stride:{prev_attn_out.stride(), cur_attn_out.stride()}, softmax shape:{prev_softmax_sum.shape}, stride:{prev_softmax_sum.stride(), cur_softmax_sum.stride()}")
+
+ (S, B, H, N) = (4096, 1, 6144, 48)
+ DS = 2 * S
+ DTYPE_ATTN = torch.bfloat16
+ DTYPE = torch.float32
+ F32_BLK_SIZE = 8
+ for i in range(1):
+ print("prove_forward_update round:", i)
+ prev_attn_out = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu()
+ prev_softmax_max = torch.rand((B, N,
DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + prev_softmax_sum = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_attn_out = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + cur_softmax_max = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_softmax_sum = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + + prev_attn_out_s = prev_attn_out.view(2, S, B, H)[1] + prev_softmax_max_s = prev_softmax_max.view(B, N, 2, S, 8)[:, :, 1, :, :] + prev_softmax_sum_s = prev_softmax_sum.view(B, N, 2, S, 8)[:, :, 1, :, :] + cur_attn_out_s = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu() + cur_softmax_max_s = torch.rand((B, N, S), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_softmax_sum_s = torch.rand((B, N, S), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + + print("--------------------prove_forward_update:2S----------------------") + data_validation(prev_softmax_max, cur_softmax_max, prev_softmax_sum, cur_softmax_sum, prev_attn_out, + cur_attn_out) + + print("--------------------prove_forward_update:1S---------------------- ") + data_validation(prev_softmax_max_s, cur_softmax_max_s, prev_softmax_sum_s, cur_softmax_sum_s, prev_attn_out_s, + cur_attn_out_s) + + +def benchmark_forward_update(): + (S, B, H, N) = (4096, 1, 6144, 48) + DS = 2 * S + DTYPE_ATTN = torch.bfloat16 + DTYPE = torch.float32 + F32_BLK_SIZE = 8 + prev_attn_out = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + prev_softmax_max = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + prev_softmax_sum = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_attn_out = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + cur_softmax_max = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_softmax_sum = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + + prev_attn_out_s = prev_attn_out.view(2, S, B, H)[1] + prev_softmax_max_s = prev_softmax_max.view(B, N, 2, S, 8)[:, :, 1, :, :] + prev_softmax_sum_s = prev_softmax_sum.view(B, N, 2, S, 8)[:, :, 1, :, :] + cur_attn_out_s = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu() + cur_softmax_max_s = torch.rand((B, N, S), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_softmax_sum_s = torch.rand((B, N, S), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + + benchmark_test(forward_update, forward_update_triton, args=(prev_attn_out, + prev_softmax_max, prev_softmax_sum, cur_attn_out, + cur_softmax_max, cur_softmax_sum), + name="forward_update_2s", times=10, repeat=10) + + benchmark_test(forward_update, forward_update_triton, args=(prev_attn_out, + prev_softmax_max, prev_softmax_sum, cur_attn_out, + cur_softmax_max, cur_softmax_sum), + name="forward_update_s", times=10, repeat=10) + + +def profile_forward_update(): + (S, B, H, N) = (4096, 1, 6144, 48) + # (S, B, H, N) = (2048,1,1536,12) + DS = 2 * S + DTYPE_ATTN = torch.bfloat16 + DTYPE = torch.float32 + F32_BLK_SIZE = 8 + prev_attn_out = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + prev_softmax_max = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + prev_softmax_sum = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_attn_out = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + cur_softmax_max = torch.rand((B, N, DS), 
dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_softmax_sum = torch.rand((B, N, DS), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + + prev_attn_out_s = prev_attn_out.view(2, S, B, H)[1] + prev_softmax_max_s = prev_softmax_max.view(B, N, 2, S, 8)[:, :, 1, :, :] + prev_softmax_sum_s = prev_softmax_sum.view(B, N, 2, S, 8)[:, :, 1, :, :] + cur_attn_out_s = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu() + cur_softmax_max_s = torch.rand((B, N, S), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + cur_softmax_sum_s = torch.rand((B, N, S), dtype=DTYPE).npu().unsqueeze(3).repeat(1, 1, 1, F32_BLK_SIZE) + + profile_test(forward_update, forward_update_triton, args=(prev_attn_out, + prev_softmax_max, prev_softmax_sum, cur_attn_out, + cur_softmax_max, cur_softmax_sum), + name="forward_update_2s", times=10, repeat=10) + + profile_test(forward_update, forward_update_triton, args=(prev_attn_out, + prev_softmax_max, prev_softmax_sum, cur_attn_out, + cur_softmax_max, cur_softmax_sum), + name="forward_update_s", times=10, repeat=10) + + +def backward_update(dq, dk, dv, cur_dq, cur_dk, cur_dv, i=7, rank=1): + cp_size = 8 + if i >= cp_size - rank - 1: + if i == cp_size - 1: + cur_dq = cur_dq.view(dq.shape) + cur_dk = cur_dk.view(dk.shape) + cur_dv = cur_dv.view(dv.shape) + dq.add_(cur_dq) + dk.add_(cur_dk) + dv.add_(cur_dv) + else: + cur_dq = cur_dq.view(dq.shape) + dq.add_(cur_dq) + dk[0].add_(cur_dk) + dv[0].add_(cur_dv) + else: + dq[1].add_(cur_dq) + cur_dk = cur_dk.view(dk.shape) # [2s, b, h] -> [2, s, b, h] + cur_dv = cur_dv.view(dv.shape) + dk.add_(cur_dk) + dv.add_(cur_dv) + + +def prove_backward_update(): + (S, B, H) = (16384, 1, 1536) + DS = 2 * S + DTYPE_ATTN = torch.bfloat16 + DTYPE = torch.float32 + + d_dq = torch.randn((2, S, B, H), dtype=DTYPE_ATTN).npu() + cur_dq = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu() + d_cur_dq = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + + d_dk = torch.randn((2, S, B, H), dtype=DTYPE_ATTN).npu() + cur_dk = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu() + d_cur_dk = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + + d_dv = torch.randn((2, S, B, H), dtype=DTYPE_ATTN).npu() + cur_dv = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu() + d_cur_dv = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu() + + def data_validate(dq, dk, dv, cur_dq, cur_dk, cur_dv, i, rank): + dq_c = dq.detach().clone() + dk_c = dk.detach().clone() + dv_c = dv.detach().clone() + + if (i == 7): + backward_update(dq, dk, dv, cur_dq, cur_dk, cur_dv, i, rank) + backward_update_triton(dq_c, dk_c, dv_c, cur_dq, cur_dk, cur_dv) + elif (i == 6): + backward_update(dq, dk, dv, cur_dq, cur_dk, cur_dv, i, rank) + backward_update_triton(dq_c, dk_c[0], dv_c[0], cur_dq, cur_dk, cur_dv) + else: + backward_update(dq, dk, dv, cur_dq, cur_dk, cur_dv, i, rank) + backward_update_triton(dq_c[1], dk_c, dv_c, cur_dq, cur_dk, cur_dv) + + torch.testing.assert_close(dq, dq_c) + print("dq comparison passed") + torch.testing.assert_close(dk, dk_c) + print("dk comparison passed") + torch.testing.assert_close(dv, dv_c) + print("passed comparison ") + + print("--------------------prove_backward_update case 0----------------------") + data_validate(d_dq, d_dk, d_dv, d_cur_dq, d_cur_dk, d_cur_dv, 7, 1) + print("--------------------prove_backward_update case 1----------------------") + data_validate(d_dq, d_dk, d_dv, d_cur_dq, cur_dk, cur_dv, 6, 1) + print("--------------------prove_backward_update case 2----------------------") + data_validate(d_dq, d_dk, d_dv, cur_dq, d_cur_dk, 
d_cur_dv, 5, 1)
+
+
+def benchmark_backward_update():
+ (S, B, H) = (16384, 1, 1536)
+ DS = 2 * S
+ DTYPE_ATTN = torch.bfloat16
+ DTYPE = torch.float32
+
+ d_dq = torch.randn((2, S, B, H), dtype=DTYPE_ATTN).npu()
+ cur_dq = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu()
+ d_cur_dq = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu()
+
+ d_dk = torch.randn((2, S, B, H), dtype=DTYPE_ATTN).npu()
+ cur_dk = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu()
+ d_cur_dk = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu()
+
+ d_dv = torch.randn((2, S, B, H), dtype=DTYPE_ATTN).npu()
+ cur_dv = torch.randn((S, B, H), dtype=DTYPE_ATTN).npu()
+ d_cur_dv = torch.randn((DS, B, H), dtype=DTYPE_ATTN).npu()
+ # benchmark case 0
+ benchmark_test(backward_update, backward_update_triton, args=(d_dq, d_dk, d_dv, d_cur_dq, d_cur_dk, d_cur_dv),
+ name="backward_update_0")
+
+
+def forward_update_la(prev_attn_out, prev_softmax_log_max_sum,
+ cur_attn_out, cur_softmax_log_max_sum):
+ if prev_attn_out is None:
+ return cur_attn_out, cur_softmax_log_max_sum
+ softmax_log_max_sum = torch.log(torch.exp(cur_softmax_log_max_sum) + torch.exp(prev_softmax_log_max_sum))
+ attn_out = torch.exp(prev_softmax_log_max_sum - softmax_log_max_sum) * prev_attn_out + torch.exp(
+ cur_softmax_log_max_sum - softmax_log_max_sum) * cur_attn_out
+ return attn_out, softmax_log_max_sum
+
+
+# simulate the original code: call forward_update, then copy the results back in place
+def forward_update_copy(prev_attn_out, prev_softmax_log_max_sum,
+ cur_attn_out, cur_softmax_log_max_sum):
+ attn_out, softmax = forward_update_la(prev_attn_out, prev_softmax_log_max_sum,
+ cur_attn_out, cur_softmax_log_max_sum)
+ prev_attn_out.copy_(attn_out)
+ prev_softmax_log_max_sum.copy_(softmax)
+
+
+def prove_forward_update_la():
+ def data_validation(prev_attn_out, prev_softmax_sum, cur_attn_out, cur_softmax_sum):
+
+ (attn_out, softmax_sum) = forward_update_la(prev_attn_out, prev_softmax_sum, cur_attn_out, cur_softmax_sum)
+ forward_update_triton_la(prev_attn_out, prev_softmax_sum, cur_attn_out, cur_softmax_sum)
+
+ try:
+ torch.testing.assert_close(softmax_sum, prev_softmax_sum)
+ print("softmax comparison passed.")
+ torch.testing.assert_close(attn_out, prev_attn_out)
+ print("attn comparison passed.")
+
+ except Exception as e:
+ print(e)
+ print("comparison not passed")
+
+ print(
+ f"proving finished, attn shape:{prev_attn_out.shape}, stride:{prev_attn_out.stride(), cur_attn_out.stride()}, softmax shape:{prev_softmax_sum.shape}, stride:{prev_softmax_sum.stride(), cur_softmax_sum.stride()}")
+
+ (S, B, N, D) = (4096, 1, 48, 128)
+ DS = 2 * S
+ DTYPE_ATTN = torch.float32
+ DTYPE = torch.float32
+
+ for i in range(1):
+ print("round:", i)
+ attn_out = torch.randn((B, N, DS, D), dtype=DTYPE_ATTN).npu()
+ softmax_sum = torch.rand((B, N, DS, 1), dtype=DTYPE).npu()
+
+ cur_attn_out = torch.randn((B, N, DS, D), dtype=DTYPE_ATTN).npu()
+ cur_softmax_sum = torch.rand((B, N, DS, 1), dtype=DTYPE).npu()
+
+ cur_attn_out_s = torch.randn((B, N, S, D), dtype=DTYPE_ATTN).npu()
+ cur_softmax_sum_s = torch.rand((B, N, S, 1), dtype=DTYPE).npu()
+
+ print("---------------------prove_forward_update_la 2S-------------------------------")
+ data_validation(attn_out, softmax_sum, cur_attn_out, cur_softmax_sum)
+
+ print("--------------------prove_forward_update_la 1S------------------------------")
+ # [b, n, 2s, d] -> [b, n, 2, s, d]
+ attn_out_s = attn_out.view(*attn_out.shape[:2], 2, attn_out.shape[2] // 2, attn_out.shape[-1])
+ # [b, n, 2s, 1] -> [b, n, 2, s, 1]
+ softmax_sum_s =
softmax_sum.view(*softmax_sum.shape[:2], 2, softmax_sum.shape[2] // 2, softmax_sum.shape[-1]) + + data_validation(attn_out_s[:, :, 1], softmax_sum_s[:, :, 1], cur_attn_out_s, cur_softmax_sum_s) + + +def benchmark_forward_update_la(): + (S, B, N, D) = (4096, 1, 48, 128) + DS = 2 * S + DTYPE_ATTN = torch.float32 + DTYPE = torch.float32 + + attn_out = torch.randn((B, N, DS, D), dtype=DTYPE_ATTN).npu() + softmax_sum = torch.rand((B, N, DS, 1), dtype=DTYPE).npu() + cur_attn_out = torch.randn((B, N, DS, D), dtype=DTYPE_ATTN).npu() + cur_softmax_sum = torch.rand((B, N, DS, 1), dtype=DTYPE).npu() + cur_attn_out_s = torch.randn((B, N, S, D), dtype=DTYPE_ATTN).npu() + cur_softmax_sum_s = torch.rand((B, N, S, 1), dtype=DTYPE).npu() + attn_out_s = attn_out.view(*attn_out.shape[:2], 2, attn_out.shape[2] // 2, attn_out.shape[-1]) + softmax_sum_s = softmax_sum.view(*softmax_sum.shape[:2], 2, softmax_sum.shape[2] // 2, softmax_sum.shape[-1]) + + benchmark_test(forward_update_copy, forward_update_triton_la, args=(attn_out_s[:, :, 1], softmax_sum_s[:, :, 1], + cur_attn_out_s, cur_softmax_sum_s), + name="forward_update_la", profile=False, repeat=1000) + +@pytest.mark.skip(reason="attn_cp") +def test_prove_forward_update(): + prove_forward_update() + +@pytest.mark.skip(reason="attn_cp") +def test_prove_forward_update_la(): + prove_forward_update_la() + +@pytest.mark.skip(reason="attn_cp") +def test_prove_backward_update(): + prove_backward_update() + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/third_party/ascend/examples/pytest_ut/test_autotune_auto_prof.py b/third_party/ascend/examples/pytest_ut/test_autotune_auto_prof.py new file mode 100644 index 000000000..7c993daa6 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_autotune_auto_prof.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import os +import shutil +import itertools + +import pytest +import triton +import triton.language as tl +import torch +import torch_npu + + +_AUTO_PROF_DIR1 = "./TEST_AUTO_PROF1" +_AUTO_PROF_DIR2 = "./TEST_AUTO_PROF2" + + +def get_autotune_config(): + configs = [] + block_size_list = [1024, 2048] + + multibuffer_list = [False] + for combo in itertools.product( + block_size_list, + multibuffer_list, + ): + ( + block_size, + multibuffer, + ) = combo + + configs.append( + triton.Config( + { + "BLOCK_SIZE": block_size, + }, + multibuffer=multibuffer, + ) + ) + + return configs + + +@triton.autotune( + configs=get_autotune_config(), + key=["n_elements"], + auto_profile_dir=_AUTO_PROF_DIR1, # auto profile the best configuration and store the result +) +@triton.jit +def add_kernel_1( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. 
+ # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# ASCEND-affinity-aware autotune +@triton.autotune( + configs=[], + key={"x": "n_elements"}, + split_params={"x": "BLOCK_SIZE"}, + tiling_params={}, + low_dims=["x"], + persistent_reduction=False, + dual_reduction=False, + auto_profile_dir=_AUTO_PROF_DIR2, +) +@triton.jit +def add_kernel_2( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.fixture(autouse=True, scope="session") +def cleanup_prof_dirs(): + # setup: ensure directories don't exist before test + if os.path.exists(_AUTO_PROF_DIR1): + shutil.rmtree(_AUTO_PROF_DIR1) + if os.path.exists(_AUTO_PROF_DIR2): + shutil.rmtree(_AUTO_PROF_DIR2) + + yield + + # teardown: clean up directories after test + if os.path.exists(_AUTO_PROF_DIR1): + shutil.rmtree(_AUTO_PROF_DIR1) + if os.path.exists(_AUTO_PROF_DIR2): + shutil.rmtree(_AUTO_PROF_DIR2) + + +@pytest.mark.parametrize( + "size, fn, prof_dir", + [ + (98432, add_kernel_1, _AUTO_PROF_DIR1), + (98432, add_kernel_2, _AUTO_PROF_DIR2), + ], +) +def test(size, fn, prof_dir): + x = torch.rand(size, device="npu") + y = torch.rand(size, device="npu") + + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + fn[grid](x, y, output, n_elements) + + assert os.path.exists(prof_dir), f"Profiling directory {prof_dir} was not created!" + + prof_files = os.listdir(prof_dir) + assert len(prof_files) > 0, f"No profiling files found in {prof_dir}" \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_block_ptr.py b/third_party/ascend/examples/pytest_ut/test_block_ptr.py new file mode 100644 index 000000000..7b95c8de6 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_block_ptr.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 
2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest + +@triton.jit +def fn_npu_(output_ptr, x_ptr,y_ptr,z_ptr,output_ptr1,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + block_ptr_in=tl.make_block_ptr( + base = x_ptr, + shape = (XB,YB,ZB), + strides = (YB*ZB,ZB,1), + offsets = (0,0,0), + block_shape = (XB,YB,ZB), + order = (2,1,0), + ) + X = tl.load(block_ptr_in) + + oidx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + block_ptr_out=tl.make_block_ptr( + base = output_ptr, + shape = (XB,YB,ZB), + strides = (YB*ZB,ZB,1), + offsets = (0,0,0), + block_shape = (XB,YB,ZB), + order = (2,1,0), + ) + tl.store(block_ptr_out,X) + +paras = [ + ('*fp32',eval('torch.float32'),2,256,16), + ('*fp32',eval('torch.float32'),8,8,4), + ('*fp16',eval('torch.float16'),2,256,16), + ('*fp16',eval('torch.float16'),8,8,4), + ('*i8',eval('torch.int8'),2,256,16), + ('*i8',eval('torch.int8'),8,8,4), +] + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', paras) +def test_npu(para_type,data_type,XB,YB,ZB): + + x = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + y = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + z = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + + print(f"shape = {x.shape}") + print(x.dtype) + + output = torch.randint(1, (XB,YB,ZB), dtype=data_type).npu() + output1 = output + print(f"output.dtype={output.dtype}") + + a = x + print(a) + fn_npu_[1,1,1](output,x,y,z,output1, XB=XB, YB=YB, ZB=ZB, debug=True) + print(output) + torch.testing.assert_close(output,a) diff --git a/third_party/ascend/examples/pytest_ut/test_broadcast_op.py b/third_party/ascend/examples/pytest_ut/test_broadcast_op.py new file mode 100644 index 000000000..6f8864063 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_broadcast_op.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu + +NBLOCKS = 1 +XS : tl.constexpr = 128 +YS : tl.constexpr = 4 +ZS : tl.constexpr = 8 +NUMEL : tl.constexpr = XS * ZS + +@triton.jit +def fn_broadcast(output_ptr, x_ptr, length): + col_offsets = tl.arange(0, NUMEL) + input = tl.load(x_ptr + col_offsets) + result = input.reshape((XS, 1, ZS)).broadcast_to((XS, YS, ZS)).reshape((XS * YS * ZS)) + brc_col_offsets = tl.arange(0, NUMEL * YS) + tl.store(output_ptr + brc_col_offsets, result) + +def test_broadcast(): + length = NUMEL + + x = torch.randn((XS, 1, ZS), dtype=torch.float32).npu() + output = torch.randn((XS, YS, ZS), dtype=torch.float32).npu() + fn_broadcast[NBLOCKS,1,1](output, x, length, debug=True) + assert(torch.equal(output, x.repeat(1, YS, 1))) + +if __name__ == "__main__": + test_broadcast() + diff --git a/third_party/ascend/examples/pytest_ut/test_cast_full.py b/third_party/ascend/examples/pytest_ut/test_cast_full.py new file mode 100644 index 000000000..a14b8103e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cast_full.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +Dimensions = tuple[int, int, int] + + +@triton.jit +def cast_to_bool(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.int1, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_i8(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.int8, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +def cast_to_i16(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.int16, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_i32(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.int32, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_i64(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.int64, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_fp32(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.float32, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_fp16(output_ptr, x_ptr, dims: Dimensions, overflow_mode: tl.constexpr): + xidx = tl.arange(0, dims.XB) + yidx = tl.arange(0, dims.YB) + zidx = tl.arange(0, dims.ZB) + idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :] + + X = tl.load(x_ptr + idx) + overflow_mode = "trunc" if overflow_mode == 0 else "saturate" + ret = tl.cast(X, dtype=tl.float16, overflow_mode=overflow_mode) + + tl.store(output_ptr + idx, ret) + + +@triton.jit +def cast_to_bf16(output_ptr, x_ptr, dims: Dimensions, overflow_mode: 
tl.constexpr):
+ xidx = tl.arange(0, dims.XB)
+ yidx = tl.arange(0, dims.YB)
+ zidx = tl.arange(0, dims.ZB)
+ idx = xidx[:, None, None] * YB * ZB + yidx[None, :, None] * ZB + zidx[None, None, :]
+
+ X = tl.load(x_ptr + idx)
+ overflow_mode = "trunc" if overflow_mode == 0 else "saturate"
+ ret = tl.cast(X, dtype=tl.bfloat16, overflow_mode=overflow_mode)
+
+ tl.store(output_ptr + idx, ret)
+
+
+import numpy as np
+
+
+def cast_npu(para_type, data_type, to_para, to_dtype, XB, YB, ZB, overflow_mode):
+
+ print(f"TESTING: cast from {para_type} to {to_para} in shape ({XB}, {YB}, {ZB})")
+
+ if para_type == "*i1":
+ x = torch.randint(low=0, high=2, size=(XB, YB, ZB), dtype=data_type).npu()
+ elif para_type == "*i8":
+ x = torch.randint(
+ low=-128,
+ high=128,
+ size=(XB, YB, ZB),
+ dtype=data_type,
+ ).npu()
+ elif para_type == "*i16":
+ x = torch.randint(
+ low=-32768, high=32768, size=(XB, YB, ZB), dtype=data_type
+ ).npu()
+ elif para_type == "*i32":
+ x = torch.randint(
+ low=-65536, high=65536, size=(XB, YB, ZB), dtype=data_type
+ ).npu()
+ elif para_type == "*i64":
+ x = torch.randint(
+ low=-65536, high=65536, size=(XB, YB, ZB), dtype=data_type
+ ).npu()
+ else: # float
+ x = torch.randn((XB, YB, ZB), dtype=data_type).npu()
+
+ if to_para == "*i1":
+ triton_func = cast_to_bool
+ cmp_type = "bool"
+ elif to_para == "*i8":
+ triton_func = cast_to_i8
+ cmp_type = "int8"
+ elif to_para == "*i16":
+ triton_func = cast_to_i16
+ cmp_type = "int16"
+ elif to_para == "*i32":
+ triton_func = cast_to_i32
+ cmp_type = "int32"
+ elif to_para == "*i64":
+ triton_func = cast_to_i64
+ cmp_type = "int64"
+ elif to_para == "*fp16":
+ triton_func = cast_to_fp16
+ cmp_type = "float16"
+ elif to_para == "*fp32":
+ triton_func = cast_to_fp32
+ cmp_type = "float32"
+ elif to_para == "*bf16":
+ triton_func = cast_to_bf16
+ cmp_type = "bfloat16"
+
+ output = torch.randint(1, (XB, YB, ZB), dtype=to_dtype).npu()
+
+ a = x.to(to_dtype)
+ dims = Dimensions(XB=XB, YB=YB, ZB=ZB)
+
+ triton_func[1, 1, 1](output, x, dims, overflow_mode)
+
+ test_common.validate_cmp(cmp_type, a, output, overflow_mode)
+
+
+def test_cast_high_priority_dtype():
+
+ typelist = [
+ (torch.int8, "*i8"),
+ (torch.float32, "*fp32"),
+ (torch.float16, "*fp16"),
+ ]
+
+ overflow_mode = [
+ 0, # "trunc",
+ 1, # "saturate",
+ ]
+
+ shapes = [(8, 32, 32)]
+ ContinueList = []
+ for src in typelist:
+ for dst in typelist:
+ if src != dst and (src[1], dst[1]) not in ContinueList:
+ for shape in shapes:
+ for mode in overflow_mode:
+ cast_npu(
+ src[1],
+ src[0],
+ dst[1],
+ dst[0],
+ shape[0],
+ shape[1],
+ shape[2],
+ mode,
+ )
+
+ print("test_cast_full passed") diff --git a/third_party/ascend/examples/pytest_ut/test_cat.py new file mode 100644 index 000000000..b2d94c1ab --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cat.py @@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+import triton
+import triton.language as tl
+
+import torch
+import torch_npu
+import pytest
+
+@triton.jit
+def cat_3d_kernel(x_ptr, y_ptr, output_ptr, # *Pointers* to input/output vector.
+ XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, # *Shape* of input.
+ dim: tl.constexpr + ): + if dim == 0: + X_out: tl.constexpr = 2 * XB + Y_out: tl.constexpr = YB + Z_out: tl.constexpr = ZB + x_idx = tl.arange(0, XB * YB * ZB) + input_x = tl.load(x_ptr + x_idx) + input_y = tl.load(y_ptr + x_idx) + + val = tl.cat(input_x, input_y, can_reorder=True) + + idx = tl.arange(0, X_out * Y_out * Z_out) + tl.store(output_ptr + idx, val) + + elif dim == 1: + X_out: tl.constexpr = XB + Y_out: tl.constexpr = 2 * YB + Z_out: tl.constexpr = ZB + for idx in range(X_out * Y_out * Z_out): + i = idx // (Y_out * Z_out) + remainder = idx % (Y_out * Z_out) + j = remainder // Z_out + k = remainder % Z_out + + if j < YB: + val = tl.load(x_ptr + i * YB * ZB + j * ZB + k) + else: + val = tl.load(y_ptr + i * YB * ZB + (j - YB) * ZB + k) + + tl.store(output_ptr + idx, val) + + elif dim == 2: + X_out: tl.constexpr = XB + Y_out: tl.constexpr = YB + Z_out: tl.constexpr = 2 * ZB + for idx in range(X_out * Y_out * Z_out): + i = idx // (Y_out * Z_out) + remainder = idx % (Y_out * Z_out) + j = remainder // Z_out + k = remainder % Z_out + + if k < ZB: + val = tl.load(x_ptr + i * YB * ZB + j * ZB + k) + else: + val = tl.load(y_ptr + i * YB * ZB + j * ZB + (k - ZB)) + + tl.store(output_ptr + idx, val) + +def cat_3d(x1: torch.Tensor, + x2: torch.Tensor, + dim: int): + assert x1.dim() == 3 and x2.dim() == 3, "Inputs must be 3D tensors" + if dim < 0: + dim += 3 + assert dim in (0, 1, 2), "Only dim=[-3, 2] supported" + assert x1.shape[0] == x2.shape[0] and x1.shape[1] == x2.shape[1] and x1.shape[2] == x2.shape[2], \ + "tl.cat only support tensors of same shape" + XB, YB, ZB = x1.shape + if dim == 0: + output_shape = (2 * XB, YB, ZB) + elif dim == 1: + output_shape = (XB, 2 * YB, ZB) + elif dim == 2: + output_shape = (XB, YB, 2 * ZB) + + output = torch.empty(output_shape, dtype=x1.dtype, device=x1.device) + + cat_3d_kernel[1,1,1]( + x1, x2, output, + XB, YB, ZB, + dim=dim + ) + return output + +def test_cat(): + params_list = \ + [ + ('float32', torch.float32, 2, 256, 16, 0), + ('float32', torch.float32, 8, 8, 4, 1), + ('float16', torch.float16, 2, 256, 16, 2), + ('float16', torch.float16, 8, 8, 4, -3), + ('int8', torch.int8, 2, 256, 16, -2), + ('int8', torch.int8, 8, 8, 4, -1), + ] + + for param in params_list: + [para_type, data_type, XB, YB, ZB, dim] = param + + x = torch.full((XB, YB, ZB), 100, dtype=data_type).npu() + y = torch.full((XB, YB, ZB), 30, dtype=data_type).npu() + + out_triton = cat_3d(x, y, dim) + out_torch = torch.cat([x, y], dim=dim) + + assert torch.allclose(out_triton, out_torch) + + print("All tests passed! -> OK") diff --git a/third_party/ascend/examples/pytest_ut/test_cat_dim.py b/third_party/ascend/examples/pytest_ut/test_cat_dim.py new file mode 100644 index 000000000..4e9210ef0 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cat_dim.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest + +@triton.jit +def fn3_dim0(output_ptr, x1_ptr, x2_ptr, x3_ptr, x1_shape: tl.constexpr, x2_shape: tl.constexpr, x3_shape: tl.constexpr): + idx_start = 0 + x1_idx = tl.arange(0,x1_shape) + X1 = tl.load(x1_ptr + x1_idx) + tl.store(output_ptr + x1_idx, X1) + + idx_start += x1_shape + x2_idx = tl.arange(0,x2_shape) + X2 = tl.load(x2_ptr + x2_idx) + tl.store(output_ptr + idx_start + x2_idx, X2) + + idx_start += x2_shape + x3_idx = tl.arange(0,x3_shape) + X3 = tl.load(x3_ptr + x3_idx) + tl.store(output_ptr + idx_start + x3_idx, X3) + + +@triton.jit +def fn4_dim1(output_ptr, x0_ptr, x1_ptr, x2_ptr, x3_ptr, dim0_len: tl.constexpr, + x0_len: tl.constexpr, x1_len: tl.constexpr, x2_len: tl.constexpr, x3_len: tl.constexpr): + + total_dim1_len = x0_len + x1_len + x2_len + x3_len + x0 = tl.load(x0_ptr + tl.arange(0, dim0_len * x0_len)) + x0 = x0.reshape(dim0_len, x0_len) + x1 = tl.load(x1_ptr + tl.arange(0, dim0_len * x1_len)) + x1 = x1.reshape(dim0_len, x1_len) + x2 = tl.load(x2_ptr + tl.arange(0, dim0_len * x2_len)) + x2 = x2.reshape(dim0_len, x2_len) + x3 = tl.load(x3_ptr + tl.arange(0, dim0_len * x3_len)) + x3 = x3.reshape(dim0_len, x3_len) + # (86,5), (86,36), (86,5) (86,4) + #torch.arange(0,86)[:,None] * 50 + torch.arange(0,36) + idx_start = 0 + nidx0 = (tl.arange(0, dim0_len)[:, None] * total_dim1_len + idx_start) + tl.arange(0, x0_len) + tl.store(output_ptr + nidx0, x0) + + idx_start += x0_len + nidx1 = (tl.arange(0, dim0_len)[:, None] * total_dim1_len + idx_start) + tl.arange(0, x1_len) + tl.store(output_ptr + nidx1, x1) + + idx_start += x1_len + nidx2 = (tl.arange(0, dim0_len)[:, None] * total_dim1_len + idx_start) + tl.arange(0, x2_len) + tl.store(output_ptr + nidx2, x2) + + idx_start += x2_len + nidx3 = (tl.arange(0, dim0_len)[:, None] * total_dim1_len + idx_start) + tl.arange(0, x3_len) + tl.store(output_ptr + nidx3, x3) + + +def cat_dim0(data_type): + x1_shape = (86, 48) + x2_shape = (8, 48) + x3_shape = (16, 48) + + x1 = torch.rand(x1_shape, dtype=data_type).npu() + x2 = torch.rand(x2_shape, dtype=data_type).npu() + x3 = torch.rand(x3_shape, dtype=data_type).npu() + res = torch.zeros((x1_shape[0] + x2_shape[0] + x3_shape[0], 48), dtype=data_type).npu() + fn3_dim0[(1,1,1)](res, x1, x2, x3, x1_shape[0]*x1_shape[1], x2_shape[0]*x2_shape[1], x3_shape[0]*x3_shape[1]) + + res_ref = torch.cat((x1, x2, x3), dim=0) + assert torch.allclose(res_ref, res, rtol=1e-03, atol=1e-03, equal_nan=True) + + +def cat_dim1(data_type): + #data_type = torch.float16 + x0_shape = (86,5) + x1_shape = (86,36) + x2_shape = (86,5) + x3_shape = (86,4) + + dim = 1 + x0 = torch.rand(x0_shape, dtype=data_type).npu() + x1 = torch.rand(x1_shape, dtype=data_type).npu() + x2 = torch.rand(x2_shape, dtype=data_type).npu() + x3 = torch.rand(x3_shape, dtype=data_type).npu() + res = torch.zeros((86, x0_shape[dim] + x1_shape[dim] + x2_shape[dim] + x3_shape[dim]), dtype=data_type).npu() + fn4_dim1[(1,1,1)](res, x0, x1, x2, x3, 86, x0_shape[1], x1_shape[1], x2_shape[1], x3_shape[1]) + #print("res_tri=", res) + + res_ref = torch.cat((x0, x1, x2, x3), dim = 1) + #print("res_ref=", res_ref) + assert torch.allclose(res_ref, res, rtol=1e-03, atol=1e-03, equal_nan=True) + +def test_cat(): + cat_dim0(torch.float16) + cat_dim1(torch.float16) + diff --git a/third_party/ascend/examples/pytest_ut/test_cdiv.py b/third_party/ascend/examples/pytest_ut/test_cdiv.py new file mode 100644 index 000000000..b7ca65cb4 --- /dev/null +++ 
b/third_party/ascend/examples/pytest_ut/test_cdiv.py @@ -0,0 +1,41 @@
+import pytest
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import test_common
+
+
+def torch_cdiv(x0, x1):
+ return torch.div(x0, x1, rounding_mode='trunc') + (x0 % x1 > 0).to(torch.int)
+
+
+@triton.jit
+def triton_cdiv(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
+ offset = tl.program_id(0) * XBLOCK
+ base1 = tl.arange(0, XBLOCK_SUB)
+ loops1: tl.constexpr = tl.cdiv(XBLOCK, XBLOCK_SUB)
+ for loop1 in range(loops1):
+ x_index = offset + (loop1 * XBLOCK_SUB) + base1
+ tmp0 = tl.load(in_ptr0 + x_index, None)
+ tmp1 = tl.load(in_ptr1 + x_index, None)
+ tmp2 = tl.cdiv(tmp0, tmp1)
+ tl.store(out_ptr0 + x_index, tmp2, None)
+
+
+@pytest.mark.parametrize('param_list',
+ [
+ ['int32', (4096,), 1, 4096, 4096],
+ ])
+def test_cdiv(param_list):
+ # generate the input data
+ dtype, shape, ncore, xblock, xblock_sub = param_list
+ x0 = test_common.generate_tensor(shape, dtype).npu()
+ x1 = test_common.generate_tensor(shape, dtype).npu() + 1
+ # torch reference result
+ torch_res = torch_cdiv(x0, x1)
+ # triton result
+ triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+ triton_cdiv[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub)
+ # compare the results
+ test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_ceil.py new file mode 100644 index 000000000..dafed0daf --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_ceil.py @@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+import pytest
+
+import triton
+import triton.language as tl
+import time
+
+import torch
+import torch_npu
+import test_common
+
+def torch_ceil(x0):
+ res = torch.ceil(x0)
+ return res
+
+@triton.jit
+def triton_ceil(in_ptr0, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr):
+ offset = tl.program_id(0) * XBLOCK
+ base1 = tl.arange(0, XBLOCK_SUB)
+ loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
+ for loop1 in range(loops1):
+ x0 = offset + (loop1 * XBLOCK_SUB) + base1
+ tmp0 = tl.load(in_ptr0 + (x0), None)
+ tmp1 = tl.ceil(tmp0)
+ tl.store(out_ptr0 + (x0), tmp1, None)
+
+
+@pytest.mark.parametrize('param_list',
+ [
+ # ['float16', (2, 4096, 8), 32, 2048, 64],
+ ['float32', (2, 4096, 8), 32, 2048, 64],
+ # ['int8', (2, 4096, 8), 32, 2048, 64],
+ ])
+def test_ceil(param_list):
+ dtype, shape, ncore, xblock, xblock_sub = param_list
+ x0 = test_common.generate_tensor(shape, dtype)
+ y_ref = torch_ceil(x0)
+ tyname = test_common.get_triton_sig_typename(dtype)
+
+ y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+ x0 = x0.npu()
+ triton_ceil[ncore, 1, 1](x0, y_cal, xblock, xblock_sub, debug=True)
+ y_ref = y_ref.npu()
+ test_common.validate_cmp_with_expection(dtype, y_cal, y_ref, True) diff --git a/third_party/ascend/examples/pytest_ut/test_clamp.py new file mode 100644 index 000000000..dae84ff46 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_clamp.py @@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+import pytest
+
+import triton
+import triton.language as tl
+import time
+import torch
+import torch_npu
+import test_common
+def torch_clamp_float(x0):
+ res = torch.clamp(x0, 0.0, 100.0)
+ return res
+
+@triton.jit
+def triton_clamp_float(in_ptr0, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr):
+ offset = tl.program_id(0) * XBLOCK
+ base1 = tl.arange(0, XBLOCK_SUB)
+ loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
+ for loop1 in range(loops1):
+ x0 = offset + (loop1 * XBLOCK_SUB) + base1
+ tmp0 = tl.load(in_ptr0 + (x0), None)
+ tmp1 = tl.clamp(tmp0, 0.0, 100.0)
+ tl.store(out_ptr0 + (x0), tmp1, None)
+
+@pytest.mark.parametrize('param_list',
+ [
+ # integer dtypes are not natively supported
+ ['float16', (4, 4), 4, 4, 4],
+ ['float32', (4, 4), 4, 4, 4],
+ ])
+def test_clamp(param_list):
+ dtype, shape, ncore, xblock, xblock_sub = param_list
+ x0 = test_common.generate_tensor(shape, dtype)
+ y_ref = torch_clamp_float(x0)
+ tyname = test_common.get_triton_sig_typename(dtype)
+
+ y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu()
+ x0 = x0.npu()
+ triton_clamp_float[ncore, 1, 1](x0, y_cal, xblock, xblock_sub, debug=True)
+ test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_common.py new file mode 100644 index 000000000..9b9d57dd4 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_common.py @@ -0,0 +1,164 @@
+from typing import Optional
+import torch
+import torch_npu
+import pytest
+import functools
+import re
+
+_float_dtypes = [
+ 'float32', 'float16', 'bfloat16'
+]
+_int_dtypes = [
+ 'int32', 'int64', 'int16', 'int8'
+]
+_all_dtypes_no_bool = _float_dtypes + _int_dtypes
+_all_dtypes = _all_dtypes_no_bool + ['bool']
+_32bit_dtypes = ['float32', 'int32']
+_16bit_dtypes = ['float16', 'bfloat16', 'int16']
+
+def generate_tensor(shape, dtype):
+ if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16':
+ return torch.randn(size=shape, dtype=eval('torch.' + dtype))
+ elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16':
+ return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype))
+ elif dtype == 'int8':
+ return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype))
+ elif dtype == 'bool':
+ return torch.randint(low=0, high=2, size=shape).bool()
+ else:
+ raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
+
+
+def get_triton_sig_typename(dtype):
+ if dtype == 'float32':
+ tyname = "*fp32"
+ elif dtype == 'int32':
+ tyname = "*i32"
+ elif dtype == 'int64':
+ tyname = "*i64"
+ elif dtype == 'float16':
+ tyname = "*fp16"
+ elif dtype == 'int16':
+ tyname = "*i16"
+ elif dtype == 'int8':
+ tyname = "*i8"
+ elif dtype == 'bool':
+ tyname = "*i1"
+ else:
+ raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
+ return tyname
+
+# Relative error: abs(x_ref - x_cal) / abs(x_ref)
+# Absolute error: abs(x_ref - x_cal)
+
+# calculation-type operators require different error ranges
+# This is a stricter verification that is not satisfied yet; kept here for reference
+def validate_cal(dtype, y_cal, y_ref):
+ if dtype == 'float16':
+ if torch.mean(y_ref) < 0.001:
+ assert torch.abs(y_cal - y_ref) < 0.001, "|y_cal - y_ref| < 0.001 is required !"
+ else:
+ diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001
+ # all true
+ assert diff.all(), "Relative error must be less than 0.001 !"
+ if dtype == 'float32':
+ if torch.mean(y_ref) < 0.0001:
+ assert torch.abs(y_cal - y_ref) < 0.0001, "|y_cal - y_ref| < 0.0001 is required !"
+ else:
+ diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.0001
+ assert diff.all(), "Relative error must be less than 0.0001 !"
+ elif dtype == 'bfloat16':
+ diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001
+ assert diff.all(), "Relative error must be less than 0.001 !"
+ elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16':
+ assert torch.equal(y_cal, y_ref)
+ elif dtype == 'bool':
+ assert torch.equal(y_cal, y_ref)
+ else:
+ raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
+
+# data movement and comparison ops should introduce no precision error
+def validate_cmp(dtype, y_cal, y_ref, overflow_mode: Optional[str] = None):
+ y_cal=y_cal.npu()
+ y_ref=y_ref.npu()
+ if overflow_mode == "saturate":
+ if dtype in ['float32', 'float16']:
+ min_value = torch.finfo(eval('torch.' + dtype)).min
+ max_value = torch.finfo(eval('torch.' + dtype)).max
+ elif dtype in ['int32', 'int16', 'int8']:
+ min_value = torch.iinfo(eval('torch.' + dtype)).min
+ max_value = torch.iinfo(eval('torch.' + dtype)).max
+ elif dtype == 'bool':
+ min_value = 0
+ max_value = 1
+ else:
+ raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype))
+ y_ref = torch.clamp(y_ref, min=min_value, max=max_value)
+ if dtype == 'float16':
+ torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True)
+ elif dtype == 'bfloat16':
+ torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, equal_nan=True)
+ elif dtype == 'float32':
+ torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True)
+ elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8':
+ assert torch.equal(y_cal, y_ref)
+ elif dtype == 'bool':
+ assert torch.equal(y_cal, y_ref)
+ else:
+ raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
+
+def validate_cmp_with_expection(dtype, y_cal, y_ref, expect):
+ if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16':
+ if expect:
+ assert torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True)
+ else:
+ assert not torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True)
+ elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8':
+ if expect:
+ assert torch.equal(y_cal, y_ref)
+ else:
+ assert not torch.equal(y_cal, y_ref)
+ else:
+ raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
+
+# Use the following pytest fixture to run one test case by only a single worker.
+# Refer to https://pytest-xdist.readthedocs.io/en/stable/how-to.html#making-session-scoped-fixtures-execute-only-once +@pytest.fixture(scope="function") +def pytest_runonce(worker_id, request, cache): + if (cache.get(request.node.nodeid, "none")) == "none": + cache.set(request.node.nodeid, worker_id) + else: + file_name = f"pytest_{worker_id}.txt" + with open(file_name, 'a') as file: + file.write(f"{request.node.nodeid} is already processed by {worker_id}") + return True + yield True + cache.set(request.node.nodeid, "none") + +def raises_with_match(expected_exception, match_pattern): + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + with pytest.raises(expected_exception, match=match_pattern): + return test_func(*args, **kwargs) + return wrapper + return decorator + +def capture_output(expected_output): + def decorator(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + capsys = kwargs.pop('capsys', None) + if capsys is None: + try: + capsys = pytest.fixture(capsys)() + except: + raise RuntimeError("This decorator requires pytest's capsys fixture") + test_func(capsys, *args, **kwargs) + captured = capsys.readouterr() + # pybind11::scoped_ostream_redirect captures std::cout with \x00 inserted + # for now, no idea how to eliminate \x00 from C++ side. + cleaned = re.sub(r"\x00", "", captured.out) + assert expected_output in cleaned + return wrapper + return decorator diff --git a/third_party/ascend/examples/pytest_ut/test_compile_hint.py b/third_party/ascend/examples/pytest_ut/test_compile_hint.py new file mode 100644 index 000000000..fa9a5f823 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_compile_hint.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl +import pytest +import test_common + +# eg: pytest -v test.py::test_compile_hint +############################# + + +@triton.jit +def triton_compile_hint(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.compile_hint(tmp0, "hint_a") + tl.multibuffer(tmp0, 2) + tmp2 = tmp0 + tl.compile_hint(tmp2, "hint_b", 42) + tl.compile_hint(tmp2, "hint_c", True) + tl.compile_hint(tmp2, "hint_d", [XBLOCK, XBLOCK_SUB]) + tl.store(out_ptr0 + (xindex), tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_compile_hint(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_compile_hint[(ncore, )](x0, y_cal, x0.numel(), xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_cos.py b/third_party/ascend/examples/pytest_ut/test_cos.py new file mode 100644 index 000000000..e79f1aafb --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cos.py @@ -0,0 +1,78 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_unary(x0, dtype): + res = torch.cos(x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block 
= tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = tl.cos(x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + # (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_cos_2.py b/third_party/ascend/examples/pytest_ut/test_cos_2.py new file mode 100644 index 000000000..fa419f298 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cos_2.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + +def torch_cos(x0): + res = torch.cos(x0) + return res + +@triton.jit +def triton_cos(in_ptr0, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.cos(tmp0) + tl.store(out_ptr0 + (x0), tmp1, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 32, 2048, 64], + ]) +def test_cos(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype) + y_ref = torch_cos(x0) + tyname = test_common.get_triton_sig_typename(dtype) + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + x0 = x0.npu() + triton_cos[ncore, 1, 1](x0, y_cal, xblock, xblock_sub, debug=True) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_count_dim0.py b/third_party/ascend/examples/pytest_ut/test_count_dim0.py new file mode 100644 index 000000000..055604b04 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_count_dim0.py @@ -0,0 +1,182 @@ + # -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common + +def standard_count(x0, cmp_val, dim, dtype): + res = (x0 == cmp_val).sum(dim=dim) + return res + +def standard_count_gt(x0, cmp_val, dim, dtype): + res = (x0 > cmp_val).sum(dim=dim) + return res + +def standard_count_lt(x0, cmp_val, dim, dtype): + res = (x0 < cmp_val).sum(dim=dim) + return res + +@triton.jit +def count(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None]*N + nblk_idx[None,:] + x = tl.load(in_ptr0+idx, mask = mask, other = 0) + tmp1 = (x == cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + nblk_idx, ret, mask = nmask) + +@triton.jit +def count_gt(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None]*N + nblk_idx[None,:] + x = tl.load(in_ptr0+idx, mask = mask, other = 0) + tmp1 = (x > cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + nblk_idx, ret, mask = nmask) + +@triton.jit +def count_lt(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None]*N + nblk_idx[None,:] + x = tl.load(in_ptr0+idx, mask = mask, other = 0) + tmp1 = (x < cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + nblk_idx, ret, mask = nmask) + + +shapes=[ + (57,3,64,16), (57,-32,64,32), (57,37,64,64), + (64,3,64,16), (64,-32,64,32), (64,37,64,64), + (3,3,8,8), (-32,3,32,8), (37,3,64,8), + (3,1,8,8), (-32,1,32,8), (37,1,64,8) +] + +map_for_64_t = {37:(31,32),263:(107,128)} +map_for_32_t = {263:(137,256)} + + +types0 = [ + (torch.int8,'int8'), +] +@pytest.mark.parametrize('dtype, sigtype',types0) +@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL',shapes) +def test_count_eq_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M][0] if M in map_for_64_t else M + MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL + N = map_for_64_t[N][0] if N in map_for_64_t else N + NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL + + elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': + M = map_for_32_t[M][0] if M in map_for_32_t else M + MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL + N = map_for_32_t[N][0] if N in map_for_32_t else N + NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL + + print(f"sum : ({M}, {N}) {dtype} {sigtype}") + cmp_val = 8 + x0 = test_common.generate_tensor(shape = (M,N),dtype = sigtype) + ans = standard_count(x0, cmp_val,0, dtype) + x0 = x0.npu() + print(ans) + output = torch.zeros((N,), dtype = torch.float32).npu() + count[1,1,1](x0, output, cmp_val, 0, M = M, N = N,MNUMEL = 
MNUMEL, NNUMEL = NNUMEL, debug = True) + print(output) + test_common.validate_cmp('float32', output, ans.to(torch.float32)) + +#------------------------------------------------------------------------------------- + +types1 = [ + (torch.float32,'float32'), + (torch.float32,'float16'), + (torch.int8,'int8'), +] +@pytest.mark.parametrize('dtype, sigtype',types1) +@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL',shapes) +def test_count_gt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M][0] if M in map_for_64_t else M + MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL + N = map_for_64_t[N][0] if N in map_for_64_t else N + NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL + + elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': + M = map_for_32_t[M][0] if M in map_for_32_t else M + MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL + N = map_for_32_t[N][0] if N in map_for_32_t else N + NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL + + print(f"sum : ({M}, {N}) {dtype} {sigtype}") + if dtype == torch.int8: + cmp_val = 8 + else: + cmp_val = 0.5 + x0 = test_common.generate_tensor(shape = (M,N),dtype = sigtype) + ans = standard_count_gt(x0, cmp_val,0, dtype) + x0 = x0.npu() + print(ans) + output = torch.zeros((N,), dtype = torch.float32).npu() + count_gt[1,1,1](x0, output, cmp_val, 0, M = M, N = N,MNUMEL = MNUMEL, NNUMEL = NNUMEL, debug = True) + print(output) + test_common.validate_cmp("float32", output, ans.to(torch.float32)) + + +shapes1=[ + (64,3,64,16), (64,-32,64,32), (64,37,64,64) +] +@pytest.mark.parametrize('dtype, sigtype',types1) +@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL',shapes1) +def test_count_lt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M][0] if M in map_for_64_t else M + MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL + N = map_for_64_t[N][0] if N in map_for_64_t else N + NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL + + elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': + M = map_for_32_t[M][0] if M in map_for_32_t else M + MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL + N = map_for_32_t[N][0] if N in map_for_32_t else N + NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL + + print(f"sum : ({M}, {N}) {dtype} {sigtype}") + if dtype == torch.int8: + cmp_val = 8 + else: + cmp_val = 0.5 + x0 = test_common.generate_tensor(shape = (M,N),dtype = sigtype) + ans = standard_count_lt(x0, cmp_val,0, dtype) + x0 = x0.npu() + print(ans) + output = torch.zeros((N,), dtype = torch.float32).npu() + count_lt[1,1,1](x0, output, cmp_val, 0, M = M, N = N,MNUMEL = MNUMEL, NNUMEL = NNUMEL, debug = True) + print(output) + test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git a/third_party/ascend/examples/pytest_ut/test_count_dim1.py b/third_party/ascend/examples/pytest_ut/test_count_dim1.py new file mode 100644 index 000000000..0e41232d9 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_count_dim1.py @@ -0,0 +1,190 @@ + # -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
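+# Same negative-shape convention as test_count_dim0.py. Here the reduction runs
+# along dim=1, so each kernel produces one count per row (length M) and stores it
+# through out_ptr0 + mblk_idx under mmask.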
+import pytest +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + +def standard_count(x0, cmp_val, dim, dtype): + res = (x0 == cmp_val).sum(dim=dim) + return res + +def standard_count_gt(x0, cmp_val, dim, dtype): + res = (x0 > cmp_val).sum(dim=dim) + return res + +def standard_count_lt(x0, cmp_val, dim, dtype): + res = (x0 < cmp_val).sum(dim=dim) + return res + +@triton.jit +def count(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None]*N + nblk_idx[None,:] + x = tl.load(in_ptr0+idx, mask = mask, other = 0) + tmp1 = (x == cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + mblk_idx, ret, mask = mmask) + +@triton.jit +def count_gt(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None]*N + nblk_idx[None,:] + x = tl.load(in_ptr0+idx, mask = mask, other = 0) + tmp1 = (x > cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + mblk_idx, ret, mask = mmask) + +@triton.jit +def count_lt(in_ptr0, out_ptr0, cmp_val, dim : tl.constexpr, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + mmask = mblk_idx < M + nmask = nblk_idx < N + mask = (mmask[:,None]) & (nmask[None,:]) + idx = mblk_idx[:,None]*N + nblk_idx[None,:] + x = tl.load(in_ptr0+idx, mask = mask, other = 0) + tmp1 = (x < cmp_val) + tmp2 = tmp1.to(tl.float32) + ret = tl.sum(tmp2, dim) + tl.store(out_ptr0 + mblk_idx, ret, mask = mmask) + + +# if shape axis = 32/256 , then actual shape = axis/element_size() + +shapes=[ + (57,3,64,16), (57,-32,64,32), + (64,3,64,16), (64,-32,64,32), + (3,3,8,8), (-32,3,32,8), (37,3,64,8), + (3,1,8,8), (-32,1,32,8), (37,1,64,8) +] + +map_for_64_t = {37:(31,32),263:(107,128)} +map_for_32_t = {263:(137,256)} + + +types0 = [ + (torch.int8,'int8'), +] +@pytest.mark.parametrize('dtype, sigtype',types0) +@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL',shapes) +def test_count_eq_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M][0] if M in map_for_64_t else M + MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL + N = map_for_64_t[N][0] if N in map_for_64_t else N + NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL + + elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': + M = map_for_32_t[M][0] if M in map_for_32_t else M + MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL + N = map_for_32_t[N][0] if N in map_for_32_t else N + NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL + + print(f"sum : ({M}, {N}) {dtype} {sigtype}") + cmp_val = 8 + x0 = test_common.generate_tensor(shape = (M,N),dtype = sigtype) + ans = standard_count(x0, cmp_val, 1, dtype) + x0 = x0.npu() + print(ans) + output = torch.zeros((M,), dtype = torch.float32).npu() + count[1,1,1](x0, 
output, cmp_val, 1, M = M, N = N,MNUMEL = MNUMEL, NNUMEL = NNUMEL, debug = True) + print(output) + test_common.validate_cmp('float32',output,ans.to(torch.float32)) + +#------------------------------------------------------------------------------------- + +types1 = [ + (torch.float32,'float32'), + (torch.float32,'float16'), + (torch.int8,'int8'), +] +@pytest.mark.parametrize('dtype, sigtype',types1) +@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL',shapes) +def test_count_gt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M][0] if M in map_for_64_t else M + MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL + N = map_for_64_t[N][0] if N in map_for_64_t else N + NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL + + elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': + M = map_for_32_t[M][0] if M in map_for_32_t else M + MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL + N = map_for_32_t[N][0] if N in map_for_32_t else N + NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL + + print(f"sum : ({M}, {N}) {dtype} {sigtype}") + if dtype == torch.int8: + cmp_val = 8 + else: + cmp_val = 0.5 + x0 = test_common.generate_tensor(shape = (M,N),dtype = sigtype) + ans = standard_count_gt(x0, cmp_val, 1, dtype) + x0 = x0.npu() + print(ans) + output = torch.zeros((M,), dtype = torch.float32).npu() + count_gt[1,1,1](x0, output, cmp_val, 1, M = M, N = N,MNUMEL = MNUMEL, NNUMEL = NNUMEL, debug = True) + print(output) + test_common.validate_cmp('float32',output,ans.to(torch.float32)) + + + +types2 = [ + (torch.int8,'int8') +] +shapes2=[ + (57,-32,64,32), (64,-32,64,32) +] + +@pytest.mark.parametrize('dtype, sigtype',types2) +@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL',shapes2) +def test_count_lt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M][0] if M in map_for_64_t else M + MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL + N = map_for_64_t[N][0] if N in map_for_64_t else N + NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL + + elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': + M = map_for_32_t[M][0] if M in map_for_32_t else M + MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL + N = map_for_32_t[N][0] if N in map_for_32_t else N + NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL + + print(f"sum : ({M}, {N}) {dtype} {sigtype}") + if dtype == torch.int8: + cmp_val = 8 + else: + cmp_val = 0.5 + x0 = test_common.generate_tensor(shape = (M,N),dtype = sigtype) + ans = standard_count_lt(x0, cmp_val, 1, dtype) + x0 = x0.npu() + print(ans) + output = torch.zeros((M,), dtype = torch.float32).npu() + count_lt[1,1,1](x0, output, cmp_val, 1, M = M, N = N,MNUMEL = MNUMEL, NNUMEL = NNUMEL, debug = True) + print(output) + test_common.validate_cmp('float32',output,ans.to(torch.float32)) diff --git a/third_party/ascend/examples/pytest_ut/test_cumprod.py b/third_party/ascend/examples/pytest_ut/test_cumprod.py new file mode 100644 index 000000000..40c8250f9 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cumprod.py @@ -0,0 +1,81 @@ +import pytest +import torch +import torch_npu +import triton +import 
triton.language as tl +from triton.runtime.libentry import libentry + +from test_common import _all_dtypes_no_bool, validate_cmp + + +def torch_func(x, dim, reverse): + is_bf16 = x.dtype == torch.bfloat16 + if is_bf16: + x = x.to(torch.float32) + if reverse: + x = torch.flip(x, [dim]) + res = torch.cumprod(x, dim=dim) + if is_bf16: + res = res.to(torch.bfloat16) + return res + + +@libentry() +@triton.jit +def triton_kernel( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx = idx_x[:, None] * numel_r + idx_r[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.cumprod(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +def triton_func(x, dim, reverse): + res = torch.empty_like(x) + triton_kernel[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1] + ) + return res + + +def cumprod_generate_tensor(shape, dtype): + if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': + return torch.rand(size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': + return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) + elif dtype == 'int8': + return torch.randint(low=0, high=3, size=shape, dtype=eval('torch.' + dtype)) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +# dtype=int8, reverse=True not support; +not_support_dtype = {'int8', 'bool'} +support_dtypes = [dtype for dtype in _all_dtypes_no_bool if dtype not in not_support_dtype] + + +@pytest.mark.parametrize("dtype", support_dtypes) +@pytest.mark.parametrize("shape", [(7, 23)]) +@pytest.mark.parametrize("dim", [0, 1]) +@pytest.mark.parametrize("reverse", [False]) +def test_cumprod(dtype, shape, dim, reverse): + x0 = cumprod_generate_tensor(shape=shape, dtype=dtype).npu() + triton_cal = triton_func(x0, dim, reverse) + torch_ref = torch_func(x0, dim, reverse) + validate_cmp(dtype, torch_ref, triton_cal) diff --git a/third_party/ascend/examples/pytest_ut/test_cumsum.py b/third_party/ascend/examples/pytest_ut/test_cumsum.py new file mode 100644 index 000000000..e94ff71c3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_cumsum.py @@ -0,0 +1,68 @@ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +from triton.runtime.libentry import libentry + +from test_common import _all_dtypes_no_bool, generate_tensor, validate_cmp + + +def torch_func(x, dim, reverse): + if reverse: + x = torch.flip(x, [dim]) + res = torch.cumsum(x, dim=dim) + return res + + +@libentry() +@triton.jit +def triton_kernel( + out_ptr0, + in_ptr0, + dim: tl.constexpr, + reverse: tl.constexpr, + numel_x: tl.constexpr, + numel_r: tl.constexpr, + XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr, +): + tl.static_assert( + numel_x == XBLOCK, "numel_x must be equal to XBLOCK in this kernel" + ) + tl.static_assert( + numel_r == RBLOCK, "numel_r must be equal to RBLOCK in this kernel" + ) + idx_x = tl.arange(0, XBLOCK) + idx_r = tl.arange(0, RBLOCK) + idx = idx_x[:, None] * numel_r + idx_r[None, :] + x = tl.load(in_ptr0 + idx) + ret = tl.cumsum(x, axis=dim, reverse=reverse) + tl.store(out_ptr0 + idx, ret) + + +def triton_func(x, dim, reverse): + 
res = torch.empty_like(x) + triton_kernel[1, 1, 1]( + res, x, dim, reverse, x.shape[0], x.shape[1], x.shape[0], x.shape[1] + ) + return res + + +# dtype=int8, reverse=True not support; +not_support_dtype = {'int8', 'bool'} +support_dtypes = [dtype for dtype in _all_dtypes_no_bool if dtype not in not_support_dtype] + + +@pytest.mark.parametrize("dtype", support_dtypes) +@pytest.mark.parametrize("shape", [(7, 23)]) +@pytest.mark.parametrize("dim", [0, 1]) +@pytest.mark.parametrize("reverse", [False]) +def test_cumsum(dtype, shape, dim, reverse): + x0 = generate_tensor(shape=shape, dtype=dtype).npu() + triton_cal = triton_func(x0, dim, reverse) + torch_dtype = eval('torch.' + dtype) + if torch_dtype == torch.float16 or torch_dtype == torch.float32: + x0 = x0.to(torch.float32) + torch_ref = torch_func(x0, dim, reverse).to(torch_dtype) + validate_cmp(dtype, torch_ref, triton_cal) diff --git a/third_party/ascend/examples/pytest_ut/test_debug_barrier.py b/third_party/ascend/examples/pytest_ut/test_debug_barrier.py new file mode 100644 index 000000000..38b03e54f --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_debug_barrier.py @@ -0,0 +1,41 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +def torch_pointwise(x0, x1): + res = x0 - x1 + return res + + +@triton.jit +def triton_sub(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 - tmp1 + tl.debug_barrier() + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' 
+ dtype)).npu() + triton_sub[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_device_print.py b/third_party/ascend/examples/pytest_ut/test_device_print.py new file mode 100644 index 000000000..a3dd460ea --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_device_print.py @@ -0,0 +1,136 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common +import os + + +os.environ["TRITON_DEVICE_PRINT"] = "1" +os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" +shape = (8,) +XS = 8 +XVALS_INT = [0, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + torch.iinfo(torch.int16).min, + torch.iinfo(torch.int16).max, + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + torch.iinfo(torch.int32).max+1] +XVALS_FP = [0, + torch.finfo(torch.float32).eps, + torch.finfo(torch.float16).eps, + torch.finfo(torch.bfloat16).eps, + torch.finfo(torch.float32).max, + torch.finfo(torch.float16).max, + torch.finfo(torch.bfloat16).max, + 1] + + +def torch_func(x0, x1): + res = x0 + x1 + return res + + +@triton.jit +def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): + idx = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.device_print("OUTPUT = ", tmp2) + tl.store(out_ptr0 + idx, tmp2) + + +def triton_func(x0, x1, XS): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, XS) + return out + + +@pytest.mark.skip(reason="waiting for bishengir-compile to support") +@pytest.mark.parametrize('sigtype', ['int64']) +def test_device_print_int64(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int32']) +def test_device_print_int32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int16']) +def test_device_print_int16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['int8']) +def test_device_print_int8(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.parametrize('sigtype', ['float32']) +def test_device_print_fp32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype = dtype).npu() + x1 = torch.ones(shape, dtype = dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + 
torch_ref = torch_func(x0, x1)
+    triton_cal = triton_func(x0, x1, XS)
+    test_common.validate_cmp(sigtype, triton_cal, torch_ref)
+
+
+@pytest.mark.parametrize('sigtype', ['float16'])
+def test_device_print_fp16(capsys, sigtype):
+    dtype = eval(f"torch.{sigtype}")
+    x0 = torch.zeros(shape, dtype = dtype).npu()
+    x1 = torch.ones(shape, dtype = dtype).npu()
+    for i in range(x1.numel()):
+        x1[i] = XVALS_FP[i]
+    torch_ref = torch_func(x0, x1)
+    triton_cal = triton_func(x0, x1, XS)
+    test_common.validate_cmp(sigtype, triton_cal, torch_ref)
+
+
+@pytest.mark.skip(reason="waiting for bishengir-compile to support")
+@pytest.mark.parametrize('sigtype', ['bfloat16'])
+def test_device_print_bf16(capsys, sigtype):
+    dtype = eval(f"torch.{sigtype}")
+    x0 = torch.zeros(shape, dtype = dtype).npu()
+    x1 = torch.ones(shape, dtype = dtype).npu()
+    for i in range(x1.numel()):
+        x1[i] = XVALS_FP[i]
+    torch_ref = torch_func(x0, x1)
+    triton_cal = triton_func(x0, x1, XS)
+    test_common.validate_cmp(sigtype, triton_cal, torch_ref)
diff --git a/third_party/ascend/examples/pytest_ut/test_device_print_script.py b/third_party/ascend/examples/pytest_ut/test_device_print_script.py
new file mode 100644
index 000000000..fc1e7672e
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_device_print_script.py
@@ -0,0 +1,97 @@
+import os
+import subprocess
+import re
+import pytest
+
+
+# Run test_device_print.py and redirect its output to a log file so the checks below can verify it.
+def run_test_device_print_py(test_name, log_name):
+    testfile_path = os.path.join(os.getcwd(), test_name)
+    logfile_path = os.path.join(os.getcwd(), log_name)
+
+    with open(logfile_path, 'w') as f:
+        try:
+            subprocess.run(["pytest", testfile_path], stdout=f, stderr=subprocess.STDOUT, check=True)
+            print(f"Ran '{test_name}' successfully!")
+        except Exception as e:
+            print(f"Failed to run '{test_name}': ", e)
+
+
+def assert_close(expected_output, logfile):
+    # Read the log content back
+    with open(logfile, "r", encoding="utf-8") as f:
+        raw = f.read()
+        cleaned = re.sub(r"\x00", "", raw)
+
+    # Check for the expected output inside the log content
+    assert expected_output in cleaned, f"Expected '{expected_output}' not found in log file."
+
+
+@pytest.mark.skip(reason="waiting for TA to support")
+def test_device_print_int8():
+    expected_output = "0,-128,127,0,-1,0,-1,0"
+    test_name = "test_device_print.py::test_device_print_int8[int8]"
+    log_name = "test_device_print_int8.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
+
+
+@pytest.mark.skip(reason="waiting for TA to support")
+def test_device_print_int16():
+    expected_output = "0,-128,127,-32768,32767,0,-1,0"
+    test_name = "test_device_print.py::test_device_print_int16[int16]"
+    log_name = "test_device_print_int16.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
+
+
+@pytest.mark.skip(reason="waiting for TA to support")
+def test_device_print_int32():
+    expected_output = "0,-128,127,-32768,32767,-2147483648,2147483647,-2147483648"
+    test_name = "test_device_print.py::test_device_print_int32[int32]"
+    log_name = "test_device_print_int32.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
+
+
+@pytest.mark.skip(reason="waiting for compiler to support")
+def test_device_print_int64():
+    expected_output = "???"
+    test_name = "test_device_print.py::test_device_print_int64[int64]"
+    log_name = "test_device_print_int64.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
+
+
+@pytest.mark.skip(reason="waiting for TA to support")
+def test_device_print_fp16():
+    expected_output = "0.000000,0.000000,0.000977,0.007812,inf,65504.000000,inf,1.000000"
+    test_name = "test_device_print.py::test_device_print_fp16[float16]"
+    log_name = "test_device_print_fp16.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
+
+
+@pytest.mark.skip(reason="waiting for TA to support")
+def test_device_print_fp32():
+    expected_output = "0.000000,0.000000,0.000977,0.007812,340282346638528859811704183484516925440.000000,65504.000000,338953138925153547590470800371487866880.000000,1.000000"
+    test_name = "test_device_print.py::test_device_print_fp32[float32]"
+    log_name = "test_device_print_fp32.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
+
+
+@pytest.mark.skip(reason="waiting for compiler to support")
+def test_device_print_bf16():
+    expected_output = "???"
+    test_name = "test_device_print.py::test_device_print_bf16[bfloat16]"
+    log_name = "test_device_print_bf16.log"
+    logfile = os.path.join(os.getcwd(), log_name)
+    run_test_device_print_py(test_name, logfile)
+    assert_close(expected_output, logfile)
\ No newline at end of file
diff --git a/third_party/ascend/examples/pytest_ut/test_discrete_mask_loadstore.py b/third_party/ascend/examples/pytest_ut/test_discrete_mask_loadstore.py
new file mode 100644
index 000000000..311e6a302
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_discrete_mask_loadstore.py
@@ -0,0 +1,47 @@
+import torch
+import triton
+import triton.language as tl
+import torch_npu
+import pytest
+
+
+@triton.jit
+def simple_discrete_mask_load_kernel(
+    in_ptr,
+    out_ptr,
+    M: tl.constexpr,
+    N: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    col_offs = tl.arange(0, N)
+    even_col_offs = tl.arange(0, N // 2) * 2
+    even_col_mask = even_col_offs < N
+    row_offs = tl.arange(0, M)
+    row_mask = row_offs < M
+    in_even_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :]
+    in_odd_ptr = in_ptr + row_offs[:, None] * N + even_col_offs[None, :] + 1
+    even_data = tl.load(in_even_ptr, mask=row_mask[:, None] & even_col_mask[None, :], other=0.0)
+    odd_data = tl.load(in_odd_ptr)
+    rotated_data = tl.interleave(-odd_data, even_data)
+    out_ptr = out_ptr + row_offs[:, None] * N + col_offs[None, :]
+    tl.store(out_ptr, rotated_data)
+
+
+@pytest.mark.parametrize("M", [(4)])
+@pytest.mark.parametrize("N", [(8)])
+def test_discrete_mask_load_store(M, N):
+    input_tensor = torch.arange(M * N, dtype=torch.float16, device='npu').reshape(M, N)
+    output_tensor = torch.empty_like(input_tensor)
+    grid = (1,)
+    simple_discrete_mask_load_kernel[grid](
+        input_tensor,
+        output_tensor,
+        M=M,
+        N=N,
+    )
+    even_cols = input_tensor[:, 0::2]
+    odd_cols = input_tensor[:, 1::2]
+    ref_output = torch.empty_like(input_tensor)
+    ref_output[:, 0::2] = -odd_cols
+    ref_output[:, 1::2] = even_cols
+    assert torch.allclose(output_tensor.float(), ref_output.float())
diff --git a/third_party/ascend/examples/pytest_ut/test_div.py b/third_party/ascend/examples/pytest_ut/test_div.py
new file mode 100644
index 000000000..f57506243
--- /dev/null
+++ 
b/third_party/ascend/examples/pytest_ut/test_div.py @@ -0,0 +1,45 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +def torch_pointwise(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def triton_div(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 / tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + if dtype == 'int8': + dtype = 'float32' + x1 = x1.masked_fill(x1 == 0, 1) + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + triton_div[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_dot_scaled.py b/third_party/ascend/examples/pytest_ut/test_dot_scaled.py new file mode 100644 index 000000000..ac5af3390 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_dot_scaled.py @@ -0,0 +1,128 @@ +import contextlib +import itertools +import re +import math +import textwrap +import os +import inspect +import pathlib + +import numpy as np +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +from numpy.random import RandomState +from triton.language.extra import libdevice + + +@pytest.mark.parametrize("M, N, K, rhs_scale, normal_type, acc_num, num_warps", + [(M, N, K, rhs_scale, normal_type, acc_num, 4) + for M, N, K in itertools.product([32, 64], [32, 64], [32]) + for rhs_scale in [False, True] + for normal_type in ["bf16", "fp16"] + for acc_num in [None, 1, 2]]) +def test_scaled_dot(M, N, K, rhs_scale, normal_type, num_warps, acc_num): + device = "npu" + + @triton.jit + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, + type_b: tl.constexpr, acc_num: tl.constexpr): + + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, + PACKED_BLOCK_K_A)[None, :] * stride_a1 + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, + BLOCK_N)[None, :] * stride_b1 + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + if a_scale is not None: + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + a_scale = tl.load(scale_a_ptr) + if b_scale is not None: + scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, + SCALE_BLOCK_K)[None, :] + b_scale = tl.load(scale_b_ptr) + accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, 
type_b, acc=accumulator, out_dtype=tl.float32) + if acc_num is not None: + for _ in range(acc_num): + accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, + out_dtype=tl.float32) + + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + tl.store(out_ptr, accumulator.to(a.dtype)) + + # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid + # overflow when scaling. + comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 + + torch.manual_seed(0) + + def make_arg(shape, ty): + if ty == "bf16" or ty == "fp16": + comp_dtype = torch.float16 if ty == "fp16" else torch.bfloat16 + ret = torch.randn(shape, dtype=comp_dtype, device=device) + # Clamp to avoid relative error issues + ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1) + else: + ret = torch.randint(256, shape, dtype=torch.int8, device=device) + return ret + + type_a = normal_type + type_b = type_a + + x = make_arg((M, K), type_a) + y = make_arg((K, N), type_b) + + min_scale, max_scale = (0, 142) if type_a == torch.bfloat16 else (124, 131) + scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device=device) + min_scale, max_scale = (0, 142) if type_b == torch.bfloat16 else (124, 131) + scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device=device) + + if not rhs_scale: + scale_y = None + + def golden_ref(x, scale_x, y, scale_y): + shape_expand_x = x.shape[-1] // scale_x.shape[-1] + if x.dtype == torch.bfloat16: + upscale_x = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int16) + upscale_x = (upscale_x + 127 << 7).view(torch.bfloat16) + else: + scale_fp32 = scale_x.repeat_interleave(shape_expand_x, dim=1).to(torch.int32) + scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) + upscale_x = scale_fp32.to(torch.float16) + upscale_y = None + if scale_y is None: + upscale_y = torch.ones_like(y) + else: + scale_y = scale_y.T + shape_expand_y = y.shape[0] // scale_y.shape[0] + if y.dtype == torch.bfloat16: + upscale_y = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int16) + upscale_y = (upscale_y + 127 << 7).view(torch.bfloat16) + else: + scale_fp32 = scale_y.repeat_interleave(shape_expand_y, dim=0).to(torch.int32) + scale_fp32 = (scale_fp32 + 127 << 23).view(torch.float32) + upscale_y = scale_fp32.to(torch.float16) + ret = torch.matmul(x * upscale_x, y * upscale_y) + return ret + + kernel_kwargs = {"num_warps": num_warps} + z = x.new_empty((M, N), dtype=x.dtype) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, + acc_num, **kernel_kwargs) + z_ref = golden_ref(x, scale_x, y, scale_y) + if acc_num is not None: + z_ref = z_ref * (acc_num + 1) + + atol = 1e-5 + rtol = 1e-2 + torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol) diff --git a/third_party/ascend/examples/pytest_ut/test_downgrade.py b/third_party/ascend/examples/pytest_ut/test_downgrade.py new file mode 100644 index 000000000..63b99a12b --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_downgrade.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
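+# These cases exercise the LLVM IR downgrade helpers: the newer memory(...)
+# attribute syntax is mapped back to the legacy readnone/readonly/writeonly/
+# argmemonly-style attributes, and the address-space-suffixed
+# llvm.stacksave.pN / llvm.stackrestore.pN intrinsics are renamed to their
+# unsuffixed forms, presumably so the generated IR remains consumable by an
+# older LLVM toolchain.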
+import pytest +from triton.backends.ascend.utils import downgrade_llir, _downgrade_mem_attrs, _downgrade_stacksaverestore_intrinsics + +@pytest.mark.parametrize("new_attr,legacy_attrs", [ + ("memory(none)" , ["readnone"]), + ("memory(read)" , ["readonly"]), + ("memory(write)" , ["writeonly"]), + ("memory(readwrite)" , []), + ("memory(argmem: read)" , ["readonly", "argmemonly"]), + ("memory(argmem: read, inaccessiblemem: write)" , ["inaccessiblemem_or_argmemonly"]), + ("memory(read, argmem: readwrite)" , []), + ("memory(readwrite, argmem: none)" , []), +]) +def test_mem_attrs(new_attr, legacy_attrs): + assert _downgrade_mem_attrs(new_attr).strip().split() == legacy_attrs + +@pytest.mark.parametrize("new_intr,legacy_intr", [ + ("declare ptr @llvm.stacksave.p0()" , "declare ptr @llvm.stacksave()"), + ("declare ptr addrspace(5) @llvm.stacksave.p5()" , "declare ptr addrspace(5) @llvm.stacksave()"), + ("declare void @llvm.stackrestore.p0(ptr %ptr)" , "declare void @llvm.stackrestore(ptr %ptr)"), + ("declare void @llvm.stackrestore.p5(ptr addrspace(5) %ptr)" , "declare void @llvm.stackrestore(ptr addrspace(5) %ptr)"), + ("%53 = call ptr @llvm.stacksave.p0()" , "%53 = call ptr @llvm.stacksave()"), + ("call void @llvm.stackrestore.p0(ptr %53)" , "call void @llvm.stackrestore(ptr %53)"), +]) +def test_stacksaverestore_intrinsics(new_intr, legacy_intr): + assert _downgrade_stacksaverestore_intrinsics(new_intr).strip() == legacy_intr diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_ceil.py b/third_party/ascend/examples/pytest_ut/test_elementwise_ceil.py new file mode 100644 index 000000000..f7479fb98 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_elementwise_ceil.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest + +import triton +import triton.language as tl +import time +import test_common +import os +import shutil + +import torch +import torch_npu + + +def standard_ceil(x0): + res = torch.ceil(x0) + return res + + +@triton.jit +def triton_ceil(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.ceil(x) + tl.store(out_ptr0 + idx_block, res, mask=mask) + + +types = [ + (torch.float32, 'float32'), +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +ops = [ + ('ceil', triton_ceil, standard_ceil), +] + + +@pytest.mark.parametrize('opName, tritonOp, standOp', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_elementwise_common(opName, tritonOp, standOp, dtype, sigtype, N, NUMEL): + torch.manual_seed(0) + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standOp(x0) + x0 = x0.npu() + + output = torch.zeros((N,), dtype=dtype).npu() + tritonOp[1, 1, 1](x0, output, N=N, NUMEL=NUMEL, debug=True) + test_common.validate_cmp(sigtype, output, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_clip.py b/third_party/ascend/examples/pytest_ut/test_elementwise_clip.py new file mode 100644 index 000000000..122bf03ae --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_elementwise_clip.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest + +import triton +import triton.language as tl +import time +import test_common +import os +import shutil + +import torch +import torch_npu + +def standard_clamp(x0): + res = torch.clamp(x0, min=-10, max=10) + return res +@triton.jit +def triton_clamp(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.clamp(x, -10, 10) + tl.store(out_ptr0 + idx_block, res, mask=mask) + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + (torch.bfloat16, 'bfloat16'), +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +ops = [ + ('clamp', triton_clamp, standard_clamp), +] + + +@pytest.mark.parametrize('opName, tritonOp, standOp', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_elementwise_common(opName, tritonOp, standOp, dtype, sigtype, N, NUMEL): + torch.manual_seed(0) + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standOp(x0) + x0 = x0.npu() + + output = torch.zeros((N,), dtype=dtype).npu() + tritonOp[1, 1, 1](x0, output, N=N, NUMEL=NUMEL, debug=True) + test_common.validate_cmp(sigtype, output, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_f2i.py b/third_party/ascend/examples/pytest_ut/test_elementwise_f2i.py new file mode 100644 index 000000000..8cf3088b8 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_elementwise_f2i.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
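+# These cases cover float -> int casts. Both tl.cast(x, tl.intN) and torch's
+# Tensor.to(torch.intN) are expected to truncate toward zero for in-range values,
+# so the Triton output is compared against the torch cast directly.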
+import pytest + +import triton +import triton.language as tl +import time +import test_common +import os +import shutil + +import torch +import torch_npu + +def standard_f2i32(x0): + res = x0.to(torch.int32) + return res + +def standard_f2i8(x0): + res = x0.to(torch.int8) + return res + +def standard_f2i16(x0): + res = x0.to(torch.int16) + return res + +def standard_f2i64(x0): + res = x0.to(torch.int64) + return res + +@triton.jit +def triton_f2i8(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.int8) + tl.store(out_ptr0 + idx_block, res, mask=mask) + +@triton.jit +def triton_f2i16(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.int16) + tl.store(out_ptr0 + idx_block, res, mask=mask) + +@triton.jit +def triton_f2i32(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.int32) + tl.store(out_ptr0 + idx_block, res, mask=mask) + +@triton.jit +def triton_f2i64(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.int64) + tl.store(out_ptr0 + idx_block, res, mask=mask) + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + (torch.bfloat16, 'bfloat16'), +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (3, 32), +] + +map_for_64_t = {37: 31} + +ops = [ + ('f2i8', triton_f2i8, standard_f2i8, 'int8'), + ('f2i16', triton_f2i16, standard_f2i16, 'int16'), + ('f2i32', triton_f2i32, standard_f2i32, 'int32'), + ('f2i64', triton_f2i64, standard_f2i64, 'int64'), +] + + +def continue_func(opName, d_type): + if 'f2i' in opName and 'int' in d_type: + return True + + +@pytest.mark.parametrize('opName, tritonOp, standOp, dst_sigtype', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_elementwise_common(opName, tritonOp, standOp, dst_sigtype, dtype, sigtype, N, NUMEL): + if continue_func(opName, sigtype): + return + + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standOp(x0) + x0 = x0.npu() + + output = test_common.generate_tensor(shape=(N,), dtype=dst_sigtype).npu() + tritonOp[1, 1, 1](x0, output, N=N, NUMEL=NUMEL, debug=True) + + test_common.validate_cmp(dst_sigtype, output, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_floor.py b/third_party/ascend/examples/pytest_ut/test_elementwise_floor.py new file mode 100644 index 000000000..c06ddce28 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_elementwise_floor.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest + +import triton +import triton.language as tl +import time +import test_common +import os +import shutil + +import torch +import torch_npu + + +def standard_floor(x0): + res = torch.floor(x0) + return res + + +@triton.jit +def triton_floor(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.floor(x) + tl.store(out_ptr0 + idx_block, res, mask=mask) + + +types = [ + (torch.float32, 'float32'), +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +ops = [ + ('floor', triton_floor, standard_floor), +] + + +@pytest.mark.parametrize('opName, tritonOp, standOp', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_elementwise_common(opName, tritonOp, standOp, dtype, sigtype, N, NUMEL): + torch.manual_seed(0) + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standOp(x0) + x0 = x0.npu() + + output = torch.zeros((N,), dtype=dtype).npu() + tritonOp[1, 1, 1](x0, output, N=N, NUMEL=NUMEL, debug=True) + test_common.validate_cmp(sigtype, output, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_i2f.py b/third_party/ascend/examples/pytest_ut/test_elementwise_i2f.py new file mode 100644 index 000000000..9f85f8642 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_elementwise_i2f.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import pytest + +import triton +import triton.language as tl +import time +import test_common +import os +import shutil + +import torch +import torch_npu + + +def standard_i2f_float32(x0): + res = x0.to(torch.float32) + return res + + +def standard_i2f_float16(x0): + res = x0.to(torch.float16) + return res + + +def standard_i2f_bfloat16(x0): + res = x0.to(torch.bfloat16) + return res + + +@triton.jit +def triton_i2f_float32(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.float32) + tl.store(out_ptr0 + idx_block, res, mask=mask) + + +@triton.jit +def triton_i2f_float16(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.float16) + tl.store(out_ptr0 + idx_block, res, mask=mask) + + +@triton.jit +def triton_i2f_bfloat16(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.cast(x, tl.bfloat16) + tl.store(out_ptr0 + idx_block, res, mask=mask) + + +types = [ + # (torch.int8, 'int8'), # TO BE FIXED i8 -> f16、bf16 + # (torch.int16, 'int16'), # TO BE FIXED i16 -> f32、bf16 + (torch.int32, 'int32'), # TO BE FIXED i32 -> f16、bf16 + # (torch.int64, 'int64'), # TO BE FIXED i64 -> bf16 +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (3, 32), +] + +map_for_64_t = {37: 31} + +ops = [ + # ('i2f16', triton_i2f_float16, standard_i2f_float16, 'float16'), + ('i2f32', triton_i2f_float32, standard_i2f_float32, 'float32'), + # ('i2fbf16', triton_i2f_bfloat16, standard_i2f_bfloat16, 'bfloat16'), +] + + +def continue_func(opName, d_type): + if 'i2f' in opName and 'float' in d_type: + return True + + +@pytest.mark.parametrize('opName, tritonOp, standOp, dst_sigtype', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_elementwise_common(opName, tritonOp, standOp, dst_sigtype, dtype, sigtype, N, NUMEL): + if continue_func(opName, sigtype): + return + + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standOp(x0) + x0 = x0.npu() + + output = test_common.generate_tensor(shape=(N,), dtype=dst_sigtype).npu() + tritonOp[1, 1, 1](x0, output, N=N, NUMEL=NUMEL, debug=True) + + test_common.validate_cmp(dst_sigtype, output, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_round.py b/third_party/ascend/examples/pytest_ut/test_elementwise_round.py new file mode 100644 index 000000000..f45e67c69 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_elementwise_round.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
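+# torch.round implements round-half-to-even (e.g. round(0.5) == 0), while the
+# kernel below emulates round-half-away-from-zero via floor(x + 0.5) and
+# ceil(x - 0.5); standard_round mirrors that same formula so reference and
+# kernel agree on ties.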
+import pytest + +import triton +import triton.language as tl +import test_common +import os +import shutil + +import torch +import torch_npu + + +def standard_round_to_nearest_neighbor_even(x0): # TO BE FIXED round to nearest neighbor even + res = torch.round(x0) + return res + +def standard_round(x0): + res = torch.where(x0 >= 0, torch.floor(x0 + 0.5), torch.ceil(x0 - 0.5)) + return res + +@triton.jit +def triton_round(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + mask = idx_block < N + x = tl.load(in_ptr0 + idx_block, mask=mask) + res = tl.where(x >= 0, tl.floor(x.to(tl.float32) + 0.5), tl.ceil(x.to(tl.float32) - 0.5)) + tl.store(out_ptr0 + idx_block, res, mask=mask) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + (torch.bfloat16, 'bfloat16'), +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +ops = [ + ('round', triton_round, standard_round), +] + + +@pytest.mark.parametrize('opName, tritonOp, standOp', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_elementwise_common(opName, tritonOp, standOp, dtype, sigtype, N, NUMEL): + torch.manual_seed(0) + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standOp(x0) + x0 = x0.npu() + + output = torch.zeros((N,), dtype=dtype).npu() + tritonOp[1, 1, 1](x0, output, N=N, NUMEL=NUMEL, debug=True) + test_common.validate_cmp(sigtype, output, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_eq.py b/third_party/ascend/examples/pytest_ut/test_eq.py new file mode 100644 index 000000000..5ef4bb2cc --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_eq.py @@ -0,0 +1,55 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_binary(x0, y0): + res = x0 == y0 + return res + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x == y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + (torch.int8, 'int8'), + (torch.int16, 'int16'), + (torch.int32, 'int32'), + (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + ans = standard_binary(x0, y0) + out = torch.zeros((N,), dtype=torch.bool).npu() + triton_elementwise_binary[1, 1, 1](x0, y0, out, N, NUMEL) + test_common.validate_cmp(sigtype, out, ans) diff --git 
a/third_party/ascend/examples/pytest_ut/test_eq_2.py b/third_party/ascend/examples/pytest_ut/test_eq_2.py
new file mode 100644
index 000000000..e8a2123d0
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_eq_2.py
@@ -0,0 +1,41 @@
+import pytest
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import test_common
+
+
+def torch_eq(x0, x1):
+    return x0 == x1
+
+
+@triton.jit
+def triton_eq(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
+    offset = tl.program_id(0) * XBLOCK
+    base1 = tl.arange(0, XBLOCK_SUB)
+    loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
+    for loop1 in range(loops1):
+        x_index = offset + (loop1 * XBLOCK_SUB) + base1
+        tmp0 = tl.load(in_ptr0 + x_index, None)
+        tmp1 = tl.load(in_ptr1 + x_index, None)
+        tmp2 = tmp0 == tmp1
+        tl.store(out_ptr0 + x_index, tmp2, None)
+
+
+@pytest.mark.parametrize('param_list',
+                         [
+                             ['float32', (2, 4096, 8), 2, 32768, 1024],
+                         ])
+def test_eq(param_list):
+    # generate input data
+    dtype, shape, ncore, xblock, xblock_sub = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    x1 = test_common.generate_tensor(shape, dtype).npu()
+    # torch reference
+    torch_res = torch_eq(x0, x1).to(eval('torch.' + dtype))
+    # triton result
+    triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_eq[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub)
+    # compare results
+    test_common.validate_cmp(dtype, triton_res, torch_res)
diff --git a/third_party/ascend/examples/pytest_ut/test_exp.py b/third_party/ascend/examples/pytest_ut/test_exp.py
new file mode 100644
index 000000000..32576f634
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_exp.py
@@ -0,0 +1,38 @@
+import triton
+import triton.language as tl
+import numpy as np
+import torch
+import pytest
+import test_common
+
+def torch_pointwise(x0):
+    res = torch.exp(x0)
+    return res
+
+
+@triton.jit
+def triton_exp(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
+    offset = tl.program_id(0) * XBLOCK
+    base1 = tl.arange(0, XBLOCK_SUB)
+    loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
+    for loop1 in range(loops1):
+        x0_prime = offset + (loop1 * XBLOCK_SUB) + base1
+        x0 = offset + (loop1 * XBLOCK_SUB) + base1
+        tmp0 = tl.load(in_ptr0 + (x0), None)
+        tmp2 = tl.exp(tmp0)
+        tl.store(out_ptr0 + (x0), tmp2, None)
+
+
+@pytest.mark.parametrize('param_list',
+                         [
+                             ['float32', (2, 4096, 8), 2, 32768, 1024],
+                         ]
+                         )
+
+def test_case(param_list):
+    dtype, shape, ncore, xblock, xblock_sub = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    y_ref = torch_pointwise(x0)
+    y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu()
+    triton_exp[ncore, 1, 1](x0, y_cal, xblock, xblock_sub)
+    test_common.validate_cmp(dtype, y_cal, y_ref)
diff --git a/third_party/ascend/examples/pytest_ut/test_exp2.py b/third_party/ascend/examples/pytest_ut/test_exp2.py
new file mode 100644
index 000000000..fc9ccc57b
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_exp2.py
@@ -0,0 +1,40 @@
+import pytest
+import triton
+import triton.language as tl
+import torch
+import torch_npu
+import test_common
+
+
+def torch_exp2(x0):
+    res = torch.pow(2, x0, out=None)
+    return res
+
+
+@triton.jit
+def triton_exp2(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
+    offset = tl.program_id(0) * XBLOCK
+    base1 = tl.arange(0, XBLOCK_SUB)
+    loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
+    for loop1 in range(loops1):
+        x_index = offset + (loop1 * XBLOCK_SUB) + base1
+        tmp0 = tl.load(in_ptr0 + x_index, None)
+        tmp1 = tl.exp2(tmp0)
+        tl.store(out_ptr0 + x_index, tmp1, None)
+
+
+@pytest.mark.parametrize('param_list',
+                         [
+                             ['float32', (2, 4096, 8), 2, 32768, 1024],
+                         ])
+def test_exp2(param_list):
+    # generate input data
+    dtype, shape, ncore, xblock, xblock_sub = param_list
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    # torch reference
+    torch_res = torch_exp2(x0)
+    # triton result
+    triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
+    triton_exp2[ncore, 1, 1](x0, triton_res, xblock, xblock_sub)
+    # compare results
+    test_common.validate_cmp(dtype, triton_res, torch_res)
diff --git a/third_party/ascend/examples/pytest_ut/test_exp_.py b/third_party/ascend/examples/pytest_ut/test_exp_.py
new file mode 100644
index 000000000..b71123145
--- /dev/null
+++ b/third_party/ascend/examples/pytest_ut/test_exp_.py
@@ -0,0 +1,79 @@
+import pytest
+
+import triton
+import triton.language as tl
+import test_common
+
+import torch
+import torch_npu
+
+
+def standard_unary(x0, dtype):
+    res = torch.exp(x0)
+    return res
+
+
+def standard_binary(x0, y0, dtype):
+    res = x0 + y0
+    return res
+
+
+@triton.jit
+def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr):
+    idx_block = tl.arange(0, NUMEL)
+    x = tl.load(in_ptr0 + idx_block, mask=idx_block < N)
+    ret = tl.math.exp(x)
+    tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N)
+
+
+@triton.jit
+def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr):
+    idx_block = tl.arange(0, NUMEL)
+    x = tl.load(in_ptr0 + idx_block, mask=idx_block < N)
+    y = tl.load(in_ptr1 + idx_block, mask=idx_block < N)
+    ret = x + y
+    tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N)
+
+
+types = [
+    (torch.float32, 'float32'),
+    (torch.float16, 'float16'),
+    # (torch.bfloat16, 'bfloat16'),
+    # (torch.int8, 'int8'),
+    # (torch.int16, 'int16'),
+    # (torch.int32, 'int32'),
+    # (torch.int64, 'int64'),
+]
+
+shapes = [
+    (3, 32),
+    (-32, 32),
+    (37, 64),
+    (-256, 256),
+    (781, 1024),
+]
+
+map_for_64_t = {37: 31}
+
+
+@pytest.mark.parametrize('dtype,sigtype', types)
+@pytest.mark.parametrize('N,NUMEL', shapes)
+def test_elementwise_common(dtype, sigtype, N, NUMEL):
+    N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N
+
+    if sigtype == "int64":
+        N = map_for_64_t[N] if N in map_for_64_t else N
+
+    print(f"elementwise : ({N},) {dtype} {sigtype}")
+
+    x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype)
+
+    ans = standard_unary(x0, dtype)
+    x0 = x0.npu()
+    print(ans)
+
+    out = torch.zeros((N,), dtype=dtype).npu()
+    triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True)
+    print(out)
+    
test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_expand_dims.py b/third_party/ascend/examples/pytest_ut/test_expand_dims.py new file mode 100644 index 000000000..4dc7ec66f --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_expand_dims.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest + +@triton.jit +def fn_npu_(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + X = tl.load(x_ptr+idx) + + ret = tl.expand_dims(X,2) + + oidx=xidx[:,None,None,None]*YB*ZB+yidx[None,:,None,None]*ZB+tl.arange(0,1)[None,None,:,None]+zidx[None,None,None,:] + + tl.store(output_ptr+oidx,ret) + +paras = [ + ('*fp32',eval('torch.float32'),2,256,16), + ('*fp32',eval('torch.float32'),8,8,4), + ('*fp16',eval('torch.float16'),2,256,16), + ('*fp16',eval('torch.float16'),8,8,4), + ('*i8',eval('torch.int8'),2,256,16), + ('*i8',eval('torch.int8'),8,8,4), +] + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', paras) +def test_npu(para_type,data_type,XB,YB,ZB): + + x = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + a = x.unsqueeze(2) + + print(f"shape = {x.shape}") + print(x.dtype) + print(a[0,0:16,0,0]) + + output = torch.randint(1, (XB,YB,1,ZB), dtype=data_type).npu() + + print(f"output.dtype={output.dtype}") + + fn_npu_[1,1,1](output,x, XB=XB, YB=YB, ZB=ZB, debug=True) + print(output[0,0:16,0,0]) + + torch.testing.assert_close(output,a) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_extract_slice.py b/third_party/ascend/examples/pytest_ut/test_extract_slice.py new file mode 100644 index 000000000..24380ef14 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_extract_slice.py @@ -0,0 +1,36 @@ +import torch +import torch_npu + +import triton +import triton.language as tl + +import pytest + +@triton.jit +def triton_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + out_sub = tl.extract_slice(output, [block_start], [32], [1]) + out_idx = block_start + tl.arange(0, 32) + out_msk = out_idx < n_elements + tl.store(output_ptr + out_idx, out_sub, mask=out_msk) + +def triton_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + triton_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + +def test_extract_slice(): + size = 1024 + x = torch.rand(size, device='npu') + y = torch.rand(size, device='npu') + torch_ref = x + y + triton_cal = triton_func(x, y) + torch.testing.assert_close(triton_cal[:32], torch_ref[:32]) diff --git a/third_party/ascend/examples/pytest_ut/test_fdiv.py b/third_party/ascend/examples/pytest_ut/test_fdiv.py new file mode 100644 index 000000000..e79219f29 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_fdiv.py @@ -0,0 +1,44 @@ +import pytest +import triton +import triton.language as tl +import torch 
+import torch_npu +import test_common + + +def torch_fdiv(x0, x1): + res = x0 / x1 + return res + + +@triton.jit +def triton_fdiv(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + tmp0 = tl.load(in_ptr0 + x_index) + tmp1 = tl.load(in_ptr1 + x_index) + tmp2 = tl.fdiv(tmp0, tmp1) + tl.store(out_ptr0 + x_index, tmp2) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_fdiv(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_tmp = test_common.generate_tensor(shape, dtype) + y0 = y_tmp.masked_fill(y_tmp == 0, 1) + y0 = y0.npu() + + # torch结果 + y_ref = torch_fdiv(x0, y0).to(eval('torch.' + dtype)) + # triton结果 + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_fdiv[ncore, 1, 1](x0, y0, y_cal, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_flip.py b/third_party/ascend/examples/pytest_ut/test_flip.py new file mode 100644 index 000000000..9f55dfecc --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_flip.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common +from triton.runtime.libentry import libentry + + +@pytest.mark.parametrize( + "para_type,data_type,shape", + [ + ["float32", torch.float32, (3, 11, 17)], + ["float16", torch.float16, (3, 11, 17)], + ["int8", torch.int8, (3, 11, 17)], + ], +) +def test_flip(para_type, data_type, shape): + + def torch_func(x): + return torch.flip(x, dims=(2,)) + + @libentry() + @triton.jit + def triton_kernel( + output_ptr0, in_ptr0, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr + ): + xidx = tl.arange(0, XB) + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + idx = ( + xidx[:, None, None] * YB * ZB + + yidx[None, :, None] * ZB + + zidx[None, None, :] + ) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.flip(tmp0, 2) + tl.store(output_ptr0 + idx, tmp1) + + def triton_func(x): + XB, YB, ZB = shape + y = torch.empty_like(x) + triton_kernel[1, 1, 1](y, x, XB, YB, ZB) + return y + + x = torch.randint(low=-128, high=128, size=shape, dtype=data_type).npu() + torch_ref = torch_func(x) + triton_cal = triton_func(x) + test_common.validate_cmp(para_type, torch_ref, triton_cal) diff --git a/third_party/ascend/examples/pytest_ut/test_floor.py b/third_party/ascend/examples/pytest_ut/test_floor.py new file mode 100644 index 000000000..f5f06e291 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_floor.py @@ -0,0 +1,40 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common + + +def torch_floor(x0, x1): + res = x0 + torch.floor(x1) + return res + + +@triton.jit +def triton_floor(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = x_index < xnumel + tmp0 = tl.load(in_ptr0 + x_index, xmask) + tmp1 = tl.load(in_ptr1 + x_index, xmask) + tmp2 = tmp0 + tl.floor(tmp1) + 
tl.store(out_ptr0 + x_index, tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_floor(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + y_ref = torch_floor(x0, x1) + # triton结果 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_floor[ncore, 1, 1](x0, x1, y_cal, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_floordiv.py b/third_party/ascend/examples/pytest_ut/test_floordiv.py new file mode 100644 index 000000000..43d06fbd8 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_floordiv.py @@ -0,0 +1,51 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +def torch_func(x0, x1): + res = x0 // x1 + return res + + +@triton.jit +def triton_kernel(out_ptr0, in_ptr0, in_ptr1, N: tl.constexpr): + idx = tl.arange(0, N) + x = tl.load(in_ptr0 + idx) + y = tl.load(in_ptr1 + idx) + ret = x // y + tl.store(out_ptr0 + idx, ret) + + +def triton_func(x0, x1, N): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, N) + return out + + +types = [ + "int32", +] + +shapes = [ + 3, + 32, + 37, + 256, + 781, +] + +@pytest.mark.parametrize("sigtype", types) +@pytest.mark.parametrize("N", shapes) +def test_floordiv(sigtype, N): + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x1 = x1.masked_fill(x1 == 0, 1) + + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, N) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + diff --git a/third_party/ascend/examples/pytest_ut/test_full.py b/third_party/ascend/examples/pytest_ut/test_full.py new file mode 100644 index 000000000..18a731cd2 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_full.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
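+# Note on the addressing scheme used by the fn_npu_* kernels below: each kernel
+# materializes a constant (XB, YB, ZB) block with tl.full and writes it back
+# through a row-major linear offset. A minimal sketch, using the same names
+# (XB, YB, ZB, oidx) as the kernels in this file:
+#
+#     xidx = tl.arange(0, XB); yidx = tl.arange(0, YB); zidx = tl.arange(0, ZB)
+#     oidx = xidx[:, None, None] * YB * ZB \
+#          + yidx[None, :, None] * ZB \
+#          + zidx[None, None, :]
+#
+# so element (x, y, z) lands at linear offset x * YB * ZB + y * ZB + z, which
+# matches the contiguous layout of the torch.full reference tensor that
+# test_common.validate_cmp compares against.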
+ +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu +import pytest + +@triton.jit +def fn_npu_f32(output_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + ret = tl.full((XB,YB,ZB),value = 100,dtype = tl.float32) + + oidx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + tl.store(output_ptr+oidx,ret) + +@triton.jit +def fn_npu_f16(output_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + ret = tl.full((XB,YB,ZB),value = 100,dtype = tl.float16) + + oidx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + tl.store(output_ptr+oidx,ret) + +@triton.jit +def fn_npu_i8(output_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + ret = tl.full((XB,YB,ZB),value = 100,dtype = tl.int8) + + oidx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + tl.store(output_ptr+oidx,ret) + +testlist = [ + (fn_npu_f32,'float32',torch.float32,2,256,16), + (fn_npu_f32,'float32',torch.float32,8,8,4), + + (fn_npu_f16,'float16',torch.float16,2,256,16), + (fn_npu_f16,'float16',torch.float16,8,8,4), + + (fn_npu_i8,'int8',torch.int8,2,256,16), + (fn_npu_i8,'int8',torch.int8,8,8,4), +] + +@pytest.mark.parametrize('testfunc, sigtype, dtype, XB, YB, ZB',testlist) +def test_npu(testfunc, sigtype, dtype, XB, YB, ZB): + + x = torch.full((XB,YB,ZB),100,dtype=dtype).npu() + + print(f"shape = {x.shape}") + print(x.dtype) + print(x[0,0:16,0]) + + output = torch.randint(1, (XB,YB,ZB), dtype=dtype).npu() + + print(f"output.dtype={output.dtype}") + + testfunc[1,1,1](output,XB,YB,ZB,debug=True) + print(output[0,0:16,0]) + + test_common.validate_cmp(sigtype,output,x) diff --git a/third_party/ascend/examples/pytest_ut/test_fusedattention.py b/third_party/ascend/examples/pytest_ut/test_fusedattention.py new file mode 100644 index 000000000..b57defbf6 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_fusedattention.py @@ -0,0 +1,347 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +DEVICE = "npu" + + +@triton.jit +def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, query vector + K_block_ptr, V_block_ptr, # Key and value block pointers for current stage + start_m, qk_scale, # Starting position of current query block, qk scale factor + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): + # range of values handled by this stage + # causal = true + # stage = 1 + # 因果注意力,顾名思义,它在计算时会限制信息的流动,只允许模型看到当前位置及之前的位置 + # 的信息。也就是说,当前位置的输出只能依赖于该位置及其之前的输入,而不能访问当前位置 + # 之后的信息。因果注意力保证了数据的顺序性,避免了“未来信息”的泄露。 + # 但是后面的逻辑也会触发 + if STAGE == 1: + # Stage 1: process all tokens before the query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Stage 2: process the current query block + tl.static_assert(BLOCK_M >= 
BLOCK_N) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # Align starting position + # causal = False (no need for masking) + else: + lo, hi = 0, N_CTX # Process the entire context + + # Adjust K and V block pointers to the starting position `lo` + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) # K is [HEAD_DIM, N_CTX], shift along the second dim by lo + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # V is [N_CTX, HEAD_DIM], shift along the first dim by lo + + # Index mapping for the accumulator , used for slicing when HEAD_DIM >= 256 + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + + # Iterate over all k, v blocks in the current stage and accumulate the output + for start_n in range(lo, hi, BLOCK_N): # Process BLOCK_N columns at a time + start_n = tl.multiple_of(start_n, BLOCK_N) # Align column start position + # -- Compute qk ---- + k = tl.load(K_block_ptr) + # Modify K + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + # Apply causal mask for STAGE 2 + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) # Construct upper triangular mask + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # Set invalid positions to -∞ + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Update m_ij = max(m_i, max(qk)) + qk -= m_ij[:, None] # Subtract max for softmax stability + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + qk = qk - m_ij[:, None] # Stabilize + + # Softmax weights p = exp(qk) + p = tl.math.exp(qk) + + # Convert softmax weight type depending on FP8 usage + if fp8_v: + p_cast = p.to(tl.float8e5) # Convert to FP8 format (save memory) + else: + p_cast = p.to(k.dtype) + + v = tl.load(V_block_ptr) # Load corresponding V block + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) # Softmax denominator (sum of each row) + # -- Update m_i and l_i + alpha = tl.math.exp(m_i - m_ij) # Update factor: exp difference between old and new max + l_i = l_i * alpha + l_ij # Update softmax denominator + # -- Update output accumulator -- + if HEAD_DIM < 256: + acc_ptr = acc_ptr * alpha[:, None] + acc_ptr = tl.dot(p_cast, v, acc_ptr) + else: + # 1. Load current slice of accumulator + acc = tl.load(acc_ptr + block2d_acc) + # 2. Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow) + for i in range(4): + # Calculate start/end rows for current slice + offset = i * (BLOCK_M // 4) + # Extract slice data + acc_i = tl.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = tl.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = tl.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # Incrementally update slice: acc = acc * alpha + pv + acc_i = acc_i * alpha_i[:, None] + pv_i + # Write updated slice back to accumulator + acc = tl.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # 3. 
updated accumulator + tl.store(acc_ptr + block2d_acc, acc) + + m_i = m_ij # Update current block max + # Advance V and K block pointers to next BLOCK_N range + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i + return acc_ptr, l_i, m_i + + +@triton.jit +def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, + stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr, # + stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr, # + stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr, # + stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr, # + Z: tl.constexpr, H: tl.constexpr, + N_CTX: tl.constexpr, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): + # Total number of blocks in sequence dimension (M) + NUM_BLOCKS_M = N_CTX // BLOCK_M + # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H) + NUM_BLOCKS = NUM_BLOCKS_M * Z * H + + # Current M-dimension block index + pid = tl.program_id(0) + + for block_idx in range(pid, NUM_BLOCKS, 20): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + # Create block pointers for Q, K, V, Output + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # Initialize offsets + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + + # Initialize accumulator + if HEAD_DIM < 256: + acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + else: + acc_offset = ( + off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + + task_m_idx * BLOCK_M * HEAD_DIM + ) + acc_ptr = acc + acc_offset + + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two 
loops independently + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + + m_i += tl.math.log(l_i) + if HEAD_DIM < 256: + accumulator = acc_ptr / l_i[:, None] + else: + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + accumulator = tl.load(acc_ptr + block2d_acc) + accumulator = accumulator / l_i[:, None] + + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, accumulator.to(Out.type.element_ty)) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, BM, BN): + """ + Forward computation interface: + Args: + ctx: Context object + q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM] + k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM] + v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM] + causal: Whether to enable causal attention + sm_scale: Scaling factor for QK product + BM: Q block size (BLOCK_M) + BN: K/V block size (BLOCK_N) + Returns: + o: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM] + """ + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + + + # Number of NPU cores (adjust based on hardware) + num_cores = 20 + acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[(num_cores,)]( + q, k, v, M, o, acc, sm_scale, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], + HEAD_DIM=HEAD_DIM_K, + BLOCK_M=BM, + BLOCK_N=BN, + STAGE=stage, + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + +attention = _attention.apply + + +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", [ + (1, 1, 128, 128, False, torch.float16, 32, 128), + (1, 1, 128, 128, False, torch.bfloat16, 64, 128), + (1, 2, 128, 128, False, torch.float16, 32, 128), + (1, 2, 256, 256, False, torch.bfloat16, 32, 256), + (2, 2, 128, 256, False, torch.float16, 64, 128), + (4, 32, 32, 64, False, torch.bfloat16, 32, 32), + (4, 32, 64, 64, False, torch.float16, 32, 64), + (4, 32, 1024, 64, False, torch.bfloat16, 64, 64), + (4, 32, 4096, 64, False, torch.float16, 64, 64), +]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): + # 过滤非整切案例, N_CTX 需整除 BM 和 BN, 且 HEAD_DIM 需整除 16 + if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0: + pytest.skip("Skipping non-divisible case") + + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, 
std=0.5).requires_grad_()) + + sm_scale = 0.5 + + tri_out = attention(q, k, v, causal, sm_scale, BM, BN) + ref_out = torch_npu.npu_fusion_attention( + q, k, v, H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True) + + +if __name__ == "__main__": + test_op(1, 1, 128, 128, causal=False, dtype=torch.float16, BM=32, BN=128) + test_op(1, 1, 128, 128, causal=False, dtype=torch.bfloat16, BM=64, BN=128) + test_op(1, 2, 128, 128, causal=False, dtype=torch.float16, BM=32, BN=128) + test_op(1, 2, 256, 256, causal=False, dtype=torch.bfloat16, BM=32, BN=256) + test_op(2, 2, 128, 256, causal=False, dtype=torch.float16, BM=64, BN=128) + test_op(4, 32, 32, 64, causal=False, dtype=torch.bfloat16, BM=32, BN=32) + test_op(4, 32, 64, 64, causal=False, dtype=torch.float16, BM=32, BN=64) + test_op(4, 32, 1024, 64, causal=False, dtype=torch.bfloat16, BM=64, BN=64) + test_op(4, 32, 4096, 64, causal=False, dtype=torch.float16, BM=64, BN=64) diff --git a/third_party/ascend/examples/pytest_ut/test_gather.py b/third_party/ascend/examples/pytest_ut/test_gather.py new file mode 100644 index 000000000..18ac8b3ca --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_gather.py @@ -0,0 +1,116 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import triton.language.extra.ascend.libdevice as libdevice +import numpy as np +import test_common +import pytest + +@pytest.mark.skip(reason="waiting for the compiler to support.") +@pytest.mark.parametrize("src_shape, indices_shape, axis", [ + ([4, 4], [8, 4], 0), + ([4, 64], [4, 32], 1), + ([128, 64], [128, 128], 1), +]) +def test_gather(src_shape, indices_shape, axis): + + @triton.jit + def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + gather_kernel[(1, )](src, indices, output, axis, + src.shape[0], src.shape[1], + src.stride(0), src.stride(1), + indices.shape[0], indices.shape[1], + indices.stride(0), indices.stride(1), + output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) + return output + + DEV = "npu" + src = torch.randn(src_shape, device=DEV) + indices = torch.randint(0, src.shape[axis], indices_shape, device=DEV) + ref = torch.gather(src, axis, indices) + result = triton_gather(src, axis, indices) + torch.testing.assert_close(result, ref, rtol=0, atol=0) + +@pytest.mark.parametrize('param_list', + [ + ['float16', (11, 12, 256, 512), 48], + ]) +def test_gather_flip(param_list): + + 
def torch_func(inp, idx): + return torch.gather(input=inp, dim=-1, index=idx) + + @triton.jit + def triton_kernel(dst_ptr, src_ptr, idx_ptr, + XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, + R0_BLOCK: tl.constexpr, R1_BLOCK: tl.constexpr): + pid = tl.program_id(0) + poff = pid * XBLOCK + x0_idx_base = 0 + r1_idx = tl.arange(0, R1_BLOCK) + loop0 = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for xsub_id in tl.range(loop0): + x0_idx = poff + xsub_id * XBLOCK_SUB + x0_idx_base + idx_idx = idx_ptr + x0_idx * R1_BLOCK + r1_idx + idx_blk = tl.load(idx_idx) + idx_min = tl.min(idx_blk, axis=0) + src_idx = src_ptr + x0_idx * R0_BLOCK + idx_min + r1_idx + src_blk = tl.load(src_idx) + fliped_blk = libdevice.flip(src_blk, 0) + dst_idx = dst_ptr + x0_idx * R1_BLOCK + r1_idx + tl.store(dst_idx, fliped_blk) + + def triton_func(p2c_out, p2c_att, p2c_pos, ncore): + nrows = p2c_att.shape[0] * p2c_att.shape[1] * p2c_att.shape[2] + xs = nrows // ncore + assert(xs * ncore == nrows) + xss = 1 # must be 1 + r0s = p2c_att.shape[3] + r1s = p2c_att.shape[2] + triton_kernel[ncore, 1, 1](p2c_out, p2c_att, p2c_pos, + XBLOCK=xs, XBLOCK_SUB=xss, + R0_BLOCK=r0s, R1_BLOCK=r1s) + return p2c_out + + dtype, shape, ncore = param_list + M0, M1, N0, N1 = shape + r0 = torch.arange(N0) + c0 = torch.arange(N0) + p2c_pos = r0[:, None] - c0[None, :] + N0-1 + p2c_pos = p2c_pos.broadcast_to((M0, M1, N0, N0)) + p2c_pos = p2c_pos.npu() + if (p2c_pos.dtype == torch.int64): + p2c_pos = p2c_pos.to(torch.int32) + assert(np.all(np.diff(p2c_pos.cpu()) == -1)) + p2c_att = test_common.generate_tensor(shape, dtype).npu() + p2c_out = test_common.generate_tensor(p2c_pos.shape, dtype).npu() + + p2c_ref = torch_func(p2c_att, p2c_pos) + triton_func(p2c_out, p2c_att, p2c_pos, ncore) + test_common.validate_cmp(dtype, p2c_out, p2c_ref) + +if __name__ == "__main__": + param_list = ['float16', (11, 12, 256, 512), 48] + test_gather_flip(param_list) + print("success: test_gather_flip") + test_gather([4, 64], [4, 32], 1) + print("success: test_gather") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_ge.py b/third_party/ascend/examples/pytest_ut/test_ge.py new file mode 100644 index 000000000..17ab128bb --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_ge.py @@ -0,0 +1,56 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_binary(x0, y0): + res = x0 >= y0 + return res + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x >= y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + (torch.int8, 'int8'), + (torch.int16, 'int16'), + (torch.int32, 'int32'), + (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y0 = test_common.generate_tensor(shape=(N,), 
dtype=sigtype).npu() + ans = standard_binary(x0, y0) + out = torch.zeros((N,), dtype=torch.bool).npu() + triton_elementwise_binary[1, 1, 1](x0, y0, out, N, NUMEL) + test_common.validate_cmp(sigtype, out, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_ge_2.py b/third_party/ascend/examples/pytest_ut/test_ge_2.py new file mode 100644 index 000000000..e6054b697 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_ge_2.py @@ -0,0 +1,44 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_ge(x0, x1, dtype): + res = torch.where(torch.ge(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + dtype)) + return res + + +@triton.jit +def triton_ge(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 >= tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_ge(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_ge(x0, x1, dtype) + # triton结果 + triton_res = torch.empty_like(x0) + triton_ge[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_gelu.py b/third_party/ascend/examples/pytest_ut/test_gelu.py new file mode 100644 index 000000000..be629e695 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_gelu.py @@ -0,0 +1,78 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_unary(x0, dtype): + res = x0 * 0.5 * (1.0 + torch.erf(x0 / torch.sqrt(torch.tensor(2.0)))) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = x * 0.5 * (1.0 + tl.erf(x / tl.sqrt(2.0))) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + # (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if 
sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_gt.py b/third_party/ascend/examples/pytest_ut/test_gt.py new file mode 100644 index 000000000..bac90c645 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_gt.py @@ -0,0 +1,42 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_gt(x0, x1, dtype): + res = torch.where(torch.gt(x0, x1), torch.ones_like(x0), torch.zeros_like(x0)).to(eval('torch.' + dtype)) + return res + + +@triton.jit +def triton_gt(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 > tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_gt(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_gt(x0, x1, dtype) + # triton结果 + triton_res = torch.empty_like(x0) + triton_gt[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_hd_permute.py b/third_party/ascend/examples/pytest_ut/test_hd_permute.py new file mode 100644 index 000000000..d185a0d15 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_hd_permute.py @@ -0,0 +1,33 @@ +import triton +import triton.language as tl +import torch +import torch_npu + +X_SIZE: tl.constexpr = 4 +Y_SIZE: tl.constexpr = 64 +Z_SIZE: tl.constexpr = 32 +NUMEL = X_SIZE * Y_SIZE * Z_SIZE + + +def torch_permute(x): + return x.reshape((X_SIZE, Y_SIZE, Z_SIZE)).permute(1, 0, 2).reshape((X_SIZE * Y_SIZE * Z_SIZE)) + + +@triton.jit +def triton_permute(output_ptr, input_ptr): + x_index = tl.arange(0, X_SIZE * Y_SIZE * Z_SIZE) + input_local = tl.load(input_ptr + x_index) + output_local = input_local.reshape((X_SIZE, Y_SIZE, Z_SIZE)).permute(1, 0, 2).reshape((X_SIZE * Y_SIZE * Z_SIZE)) + tl.store(output_ptr + x_index, output_local) + + +def test_hd_permute(): + # 生成数据 + x = torch.randn(NUMEL).npu() + # torch结果 + torch_res = torch_permute(x) + # triton结果 + triton_res = torch.randn(torch_res.shape, dtype=torch_res.dtype).npu() + triton_permute[1, 1, 1](triton_res, x) + # 比较结果 + torch.testing.assert_close(triton_res, torch_res, rtol=1e-3, atol=1e-3) diff --git a/third_party/ascend/examples/pytest_ut/test_if_tensor.py b/third_party/ascend/examples/pytest_ut/test_if_tensor.py new file mode 100644 index 000000000..51e8706e3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_if_tensor.py @@ -0,0 +1,35 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def if_tensor_kernel( + kv_start_idx, # tensor + output_ptr, 
+): + pid = tl.program_id(0) + if kv_start_idx: + value = tl.load(kv_start_idx + pid) + tl.store(output_ptr + pid, value) + + +# 测试函数 +def test_kernel(): + n = 8 + device = 'npu' + + kv_start_idx = torch.arange(n, dtype=torch.float32, device=device) + output1 = torch.zeros(n, dtype=torch.float32, device=device) + if_tensor_kernel[(n,)]( + kv_start_idx, output1, + ) + + expected = torch.arange(n, dtype=torch.float32, device=device) + assert torch.allclose(output1, expected), f"Output {output1} != Expected {expected}" + print(f"RESULT: output1 = {output1}") + print("✅ Test passed!") + + +if __name__ == "__main__": + test_kernel() diff --git a/third_party/ascend/examples/pytest_ut/test_index_select.py b/third_party/ascend/examples/pytest_ut/test_index_select.py new file mode 100644 index 000000000..0af28f176 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_index_select.py @@ -0,0 +1,86 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + +# shape: same as load, not scalar +# dim [0,len(shape)-1] +# indice_shape: 1D tensor, +# dtype: same as load +# profiling: 0.6x AscendC + +# known issue: runtime error when src_tensor.stride(dim) is not 32B aligned +@pytest.mark.parametrize("src_shape, dim, indice_shape, dtype", [ + ((500000, 240), 0, (375144,), "float32"), # standard 1984us +- 5us fp32 48core + # ((500000, 37), 0, (324344,), "float32"), # standard 266us +- 10us fp32 48core + ((3200, 16), 0, (1971940,), "float32"), # standard 1104us +- fp32 + # ((1971940, 2), 1, (1,), "int32"), # 4678us +- 129us INT32 + # ((3200, 1, 37), 0, (1022226,), "float32"), # 1965us +- 65us fp32 + # ((323357, 37), 0, (1022226,), "float32"), # 3500us +- 50us fp32 + ((480000, 16), 0, (3943880,), "float32"), # 3636 +- 42us fp32 + ((480000, 32), 0, (3943880,), "float32"), # 3678 +- 5us fp32 + ((480000, 64), 0, (3943880,), "float32"), # 5392 +- 20us fp32 + ((378103, 240), 0, (1992337,), "float32"), # 10000us fp32 + ((374035, 240), 0, (1971940,), "float32"), # 10100us fp32 + ((2000000, 32), 0, (270244,), "float32"), # 155us fp32s + ((1000000, 8), 0, (21329,), "float32"), # 44us fp32 + ((270098, 32), 0, (512,), "float32"), # 40us fp32 + ((20669, 8), 0, (2048,), "float32"), # 31us fp32 + ((10, 16), 0, (1024,), "float32"), # 27us fp32 + ((270610, 32), 0, (2928000,), "float32"), # 553us fp32 + ((22717, 8), 0, (278400,), "float32"), # 76us fp32 + ((1034, 16), 0, (48000,), "float32"), # 38us fp32 +]) +def test_index_select(src_shape, dim, indice_shape, dtype): + + def torch_func(x0, dim, indices): + res = torch.index_select(x0, dim, indices) + return res + + @triton.jit + def basic_index_select(in_ptr, indices_ptr, out_ptr, dim, + g_stride: tl.constexpr, indice_length: tl.constexpr, + g_block : tl.constexpr, g_block_sub: tl.constexpr, other_block:tl.constexpr): + g_begin=tl.program_id(0) * g_block + for goffs in range(0, g_block, g_block_sub): + g_idx=tl.arange(0, g_block_sub) + g_begin + goffs + g_mask = g_idx < indice_length + indices = tl.load(indices_ptr + g_idx, g_mask, other=0) + for other_offset in range(0, g_stride, other_block): + tmp_buf = tl.zeros((g_block_sub, other_block), in_ptr.dtype.element_ty) + other_idx = tl.arange(0, other_block) + other_offset + other_mask = other_idx < g_stride + for i in range(0, g_block_sub): + gather_offset = tl.get_element(indices, (i,)) * g_stride + val = tl.load(in_ptr + gather_offset + other_idx, other_mask) + tmp_buf = tl.insert_slice(tmp_buf, val[None,:], offsets=(i, 0), sizes=(1, other_block), 
strides=(g_stride, 1)) + tl.store(out_ptr + g_idx[:,None] * g_stride + other_idx[None,:], tmp_buf, g_mask[:,None] & other_mask[None,:]) + + def triton_func(x0, dim, indices): + sz = list(x0.shape) + sz[dim]=len(indices) + out = torch.empty(tuple(sz), dtype=x0.dtype).npu() + g_stride = x0.stride(dim) + indice_length=indices.numel() + num_vec_core=40 + g_block = (indice_length - 1) // num_vec_core + 1 + enable_multi_buffer=True + available_ub_space = (125 * 1024) // (x0.element_size() * (2 if enable_multi_buffer else 1)) + if g_stride * 2 < available_ub_space: + other_block = g_stride + g_block_sub = available_ub_space // other_block + else: + other_block = available_ub_space + g_block_sub = 1 + basic_index_select[num_vec_core, 1, 1](x0, indices, out, dim, g_stride = g_stride, indice_length=indice_length, + g_block = g_block, g_block_sub = g_block_sub, other_block = other_block, multibuffer=False) + return out + + x0 = test_common.generate_tensor(shape=src_shape, dtype=dtype).npu() + indices = torch.randint(0, src_shape[dim], size=indice_shape, dtype=torch.int32).npu() + + torch_ref = torch_func(x0, dim, indices) + triton_cal = triton_func(x0, dim, indices) + assert torch.equal(torch_ref, triton_cal) diff --git a/third_party/ascend/examples/pytest_ut/test_insert_slice.py b/third_party/ascend/examples/pytest_ut/test_insert_slice.py new file mode 100644 index 000000000..ce8c0647c --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_insert_slice.py @@ -0,0 +1,40 @@ +import torch +import torch_npu + +import triton +import triton.language as tl + +import pytest + +@triton.jit +def triton_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, SLICE_OFFSET: tl.constexpr, SLICE_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + x_sub = tl.extract_slice(x, [block_start+SLICE_OFFSET], [SLICE_SIZE], [1]) + y_sub = tl.extract_slice(y, [block_start+SLICE_OFFSET], [SLICE_SIZE], [1]) + output_sub = x_sub + y_sub + output = tl.load(output_ptr + offsets, mask=mask) + output = tl.insert_slice(output, output_sub, [block_start+SLICE_OFFSET], [SLICE_SIZE], [1]) + tl.store(output_ptr + offsets, output, mask=mask) + +def triton_func(x: torch.Tensor, y: torch.Tensor, slice_offset: int, slice_size: int): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + triton_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, SLICE_OFFSET=0, SLICE_SIZE=32) + return output + +def test_extract_slice(): + size = 1024 + slice_offset = 0 + slice_size = 32 + x = torch.rand(size, device='npu') + y = torch.rand(size, device='npu') + torch_ref = x + y + triton_cal = triton_func(x, y, slice_offset, slice_size) + torch.testing.assert_close(triton_cal[slice_offset:slice_offset+slice_size], + torch_ref[slice_offset:slice_offset+slice_size]) diff --git a/third_party/ascend/examples/pytest_ut/test_interleave.py b/third_party/ascend/examples/pytest_ut/test_interleave.py new file mode 100644 index 000000000..0145bae97 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_interleave.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
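+# tl.interleave, as exercised below, merges two (XB, YB, ZB) blocks along the
+# last axis into a single (XB, YB, 2*ZB) block with elements alternating
+# X, Y, X, Y, ... A minimal sketch of the torch-side equivalence this test
+# relies on (x and y are the two input tensors defined in test_interleave):
+#
+#     ans = torch.stack((x, y), dim=-1).reshape(XB, YB, 2 * ZB)
+#
+# The store offsets in the kernel therefore span 2*ZB in the innermost
+# dimension (tl.arange(0, 2 * ZB)).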
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def fn_npu_(output_ptr, x_ptr,y_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + X = tl.load(x_ptr+idx) + Y = tl.load(y_ptr+idx) + + ret = tl.interleave(X,Y) + + oidx=xidx[:,None,None]*YB*ZB*2+yidx[None,:,None]*ZB*2+tl.arange(0,2*ZB)[None,None,:] + + tl.store(output_ptr+oidx,ret) + + + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', + [ + ['float32',torch.float32,2,64,16], + ['float32',torch.float32,8,8,4], + ['float16',torch.float16,2,64,16], + ['float16',torch.float16,8,8,4], + ['int8',torch.int8,2,64,32], + ['int8',torch.int8,8,8,4], + ] + ) +def test_interleave(para_type,data_type,XB,YB,ZB): + + x = torch.full((XB,YB,ZB),100,dtype=data_type).npu() + y = torch.full((XB,YB,ZB),30,dtype=data_type).npu() + + print(f"shape = {x.shape}") + print(x.dtype) + + output = torch.randint(1, (XB,YB,ZB*2), dtype=data_type).npu() + output1 = output + print(f"output.dtype={output.dtype}") + + ans = torch.stack((x,y),dim=-1).reshape(XB,YB,ZB*2) + print(ans) + print(ans.shape) + + fn_npu_[1,1,1](output,x,y,XB,YB,ZB) + print(output) + + test_common.validate_cmp(para_type, ans, output) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_invert.py b/third_party/ascend/examples/pytest_ut/test_invert.py new file mode 100644 index 000000000..0bffa5eeb --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_invert.py @@ -0,0 +1,44 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_invert(x0): + res = ~(x0) + return res + + +@triton.jit +def triton_invert(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp2 = ~tmp0 + tl.store(out_ptr0 + (xindex), tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['int8', (2, 4096, 8), 2, 32768, 1024], + ['int16', (2, 4096, 8), 2, 32768, 1024], + ['int32', (2, 4096, 8), 2, 32768, 1024], + ['int64', (2, 4096, 8), 2, 32768, 1024], + ['bool', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_invert(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_invert(x0) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_invert[ncore, 1, 1](x0, triton_res, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_isfinited.py b/third_party/ascend/examples/pytest_ut/test_isfinited.py new file mode 100644 index 000000000..e287fd367 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_isfinited.py @@ -0,0 +1,119 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + +types = [ + "float32", + "float16", + "bfloat16", +] + +shapes = [ + # 3, + # 32, + 37, + # 256, + # 781, +] + + +@pytest.mark.parametrize("sigtype", types) +@pytest.mark.parametrize("N", shapes) +@pytest.mark.parametrize("val", ['1.0', 'nan', 'inf', '-inf']) +def test_isfinited(sigtype, N, val): + + def torch_func(x0): + res = torch.isfinite(x0) + return res + + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + ret = tl.math.isfinited(x0) + tl.store(out_ptr0 + idx, ret) + + + def triton_func(x0, N): + out = torch.zeros(x0.shape, dtype=torch.bool).npu() + triton_kernel[1, 1, 1](out, x0, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x0[1] = float(val) + + torch_ref = torch_func(x0) + triton_cal = triton_func(x0, N) + + test_common.validate_cmp("bool", triton_cal, torch_ref) + + +@pytest.mark.parametrize("N", shapes) +@pytest.mark.parametrize("val", ['1.0', 'nan', 'inf', '-inf']) +def test_finitef(N, val): + + def torch_func(x0): + res = torch.isfinite(x0) + return res + + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + ret = tl.math.finitef(x0) + tl.store(out_ptr0 + idx, ret) + + + def triton_func(x0, N): + out = torch.zeros(x0.shape, dtype=torch.bool).npu() + triton_kernel[1, 1, 1](out, x0, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype='float32').npu() + x0[1] = float(val) + + torch_ref = torch_func(x0) + triton_cal = triton_func(x0, N) + test_common.validate_cmp("bool", triton_cal, torch_ref) + + +invalid_types = [ + "int32", +] + + +@pytest.mark.parametrize("sigtype", invalid_types) +@pytest.mark.parametrize("N", shapes) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "Expected dtype fp16/fp32/bf16, but got") +def test_isfinited_invalid_dtype(sigtype, N): + + def torch_func(x0): + res = torch.isfinite(x0) + return res + + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + ret = tl.math.isfinited(x0) + tl.store(out_ptr0 + idx, ret) + + + def triton_func(x0, N): + out = torch.zeros(x0.shape, dtype=torch.bool).npu() + triton_kernel[1, 1, 1](out, x0, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x0[1] = float('nan') + + torch_ref = torch_func(x0) + triton_cal = triton_func(x0, N) + test_common.validate_cmp("bool", triton_cal, torch_ref) + assert triton_cal[1] == True diff --git a/third_party/ascend/examples/pytest_ut/test_isnan.py b/third_party/ascend/examples/pytest_ut/test_isnan.py new file mode 100644 index 000000000..cbab8c8f5 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_isnan.py @@ -0,0 +1,85 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + +types = [ + "float32", + "float16", + "bfloat16", +] + +shapes = [ + # 
3, + # 32, + 37, + # 256, + # 781, +] + +@pytest.mark.parametrize("sigtype", types) +@pytest.mark.parametrize("N", shapes) +def test_isnan(sigtype, N): + + def torch_func(x0): + res = torch.isnan(x0) + return res + + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + ret = tl.extra.ascend.libdevice.isnan(x0) + tl.store(out_ptr0 + idx, ret) + + + def triton_func(x0, N): + out = torch.zeros(x0.shape, dtype=torch.bool).npu() + triton_kernel[1, 1, 1](out, x0, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x0[1] = float('nan') + + torch_ref = torch_func(x0) + triton_cal = triton_func(x0, N) + test_common.validate_cmp("bool", triton_cal, torch_ref) + assert triton_cal[1] == True + +invalid_types = [ + "int32", +] + +@pytest.mark.parametrize("sigtype", invalid_types) +@pytest.mark.parametrize("N", shapes) +@test_common.raises_with_match(triton.compiler.errors.CompilationError, "input arg type does not match.") +def test_isnan_invalid_dtype(sigtype, N): + + def torch_func(x0): + res = torch.isnan(x0) + return res + + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + ret = tl.extra.ascend.libdevice.isnan(x0) + tl.store(out_ptr0 + idx, ret) + + + def triton_func(x0, N): + out = torch.zeros(x0.shape, dtype=torch.bool).npu() + triton_kernel[1, 1, 1](out, x0, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x0[1] = float('nan') + + torch_ref = torch_func(x0) + triton_cal = triton_func(x0, N) + test_common.validate_cmp("bool", triton_cal, torch_ref) + assert triton_cal[1] == True diff --git a/third_party/ascend/examples/pytest_ut/test_join.py b/third_party/ascend/examples/pytest_ut/test_join.py new file mode 100644 index 000000000..4c9deb0f8 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_join.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
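+# tl.join, as exercised below, stacks two (XB, YB) blocks along a new trailing
+# axis, producing an (XB, YB, 2) block. A minimal sketch of the torch-side
+# reference this test checks against (x and y are the two inputs):
+#
+#     ans = torch.stack((x, y), dim=-1)   # shape (XB, YB, 2)
+#
+# The store offsets accordingly use tl.arange(0, 2) for the innermost axis;
+# the ZB parameter is accepted by the kernel but not used in the indexing.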
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def fn_npu_(output_ptr, x_ptr,y_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + + idx=xidx[:,None]*YB+yidx[None,:] + + X = tl.load(x_ptr+idx) + Y = tl.load(y_ptr+idx) + + ret = tl.join(X,Y) + + oidx=xidx[:,None,None]*YB*2+yidx[None,:,None]*2+tl.arange(0,2)[None,None,:] + + tl.store(output_ptr+oidx,ret) + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', + [ + ['float32',torch.float32,4,64,4], + ['float32',torch.float32,8,8,4], + ['float16',torch.float16,4,64,4], + ['float16',torch.float16,8,8,4], + ['int8',torch.int8,4,128,4], + ['int8',torch.int8,8,8,4], + ] + ) +def test_join(para_type,data_type,XB,YB,ZB): + x = torch.full((XB,YB),100,dtype=data_type).npu() + y = torch.full((XB,YB),30,dtype=data_type).npu() + + ans = torch.stack((x,y),dim=-1) + print(ans) + + output = torch.randint(1, (XB,YB,2), dtype=data_type).npu() + fn_npu_[1,1,1](output,x,y,XB, YB, ZB, debug = True) + + print(output) + test_common.validate_cmp(para_type, ans, output) diff --git a/third_party/ascend/examples/pytest_ut/test_lanzcos.py b/third_party/ascend/examples/pytest_ut/test_lanzcos.py new file mode 100644 index 000000000..4fd54669d --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_lanzcos.py @@ -0,0 +1,253 @@ +import torch, torch_npu +import triton +import triton.language as tl +import numpy as np +import math +import pytest + + +@triton.jit +def lanczos_resize_kernel( + img_src_ptr, + img_dst_ptr, + img_coeffs_ptr, + src_rows, + src_cols, + dst_rows, + dst_cols, + R_H, + R_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + BLOCK_SIZE: tl.constexpr, +): + block_id_c = tl.program_id(0) + block_id_h = tl.program_id(1) + block_id_w = tl.program_id(2) + dest_h_offs = block_id_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + dest_w_offs = block_id_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + dest_offs = ( + block_id_c[None, None] * stride_out_c + + dest_h_offs[:, None] * stride_out_h + + dest_w_offs[None, :] * stride_out_w + ) + + RR_H = 1.0 / R_H + RR_W = 1.0 / R_W + + fy = (dest_h_offs + 0.5) * RR_H - 0.5 + sy = tl.floor(fy) + fx = (dest_w_offs + 0.5) * RR_W - 0.5 + sx = tl.floor(fx) + + idxY = tl.floor((fy - sy) * 24.999999).to(tl.int32) + idxX = tl.floor((fx - sx) * 24.999999).to(tl.int32) + tableIndex = idxY[:, None] * 25 + idxX[None, :] + res = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), tl.float32) + + for ii in range(4): + for jj in range(4): + src_offsets = ( + block_id_c[None, None] * stride_in_c + + (tl.clamp((sy + ii - 1), 0, src_rows - 1)).to(tl.int32)[:, None] + * stride_in_h + + (tl.clamp((sx + jj - 1), 0, src_cols - 1)).to(tl.int32)[None, :] + * stride_in_w + ) + src_val = tl.load(img_src_ptr + src_offsets) + coeffs_offs = tableIndex[:, :] * 16 + (ii * 4 + jj)[None, None] + coeffs = tl.load(img_coeffs_ptr + coeffs_offs) + res = res + src_val * coeffs + dst_mask = (dest_h_offs[:, None] < dst_rows) & (dest_w_offs[None, :] < dst_cols) + res = tl.clamp(res, 0.0, 1.0) + tl.store(img_dst_ptr + dest_offs, res, mask=dst_mask) + + +def lanczos_resize_triton(img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols): + N, C, src_rows, src_cols = img_src.shape + R_H = float(dst_rows) / src_rows + R_W = float(dst_cols) / src_cols + + stride_in_n, stride_in_c, stride_in_h, stride_in_w = img_src.stride() + stride_out_n, stride_out_c, stride_out_h, stride_out_w = 
img_dst.stride() + BLOCK_SIZE = 16 + grid = lambda meta: ( + C, + triton.cdiv(dst_rows, meta["BLOCK_SIZE"]), + triton.cdiv(dst_cols, meta["BLOCK_SIZE"]), + ) + lanczos_resize_kernel[grid]( + img_src, + img_dst, + c_lanczosCoeffs, + src_rows, + src_cols, + dst_rows, + dst_cols, + R_H, + R_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + BLOCK_SIZE=BLOCK_SIZE, + ) + return img_dst + + +def lanczos_resize_cpu(img_src, img_dst, img_coeffs, dst_rows, dst_cols): + N, C, src_rows, src_cols = img_src.shape + R_H = float(dst_rows) / src_rows + R_W = float(dst_cols) / src_cols + for i in range(dst_rows): + for j in range(dst_cols): + RR_H = 1.0 / R_H + RR_W = 1.0 / R_W + fy = (i + 0.5) * RR_H - 0.5 + sy = math.floor(fy) + fx = (j + 0.5) * RR_W - 0.5 + sx = math.floor(fx) + idxY = math.floor((fy - np.floor(fy)) * 24.999999) + idxX = math.floor((fx - np.floor(fx)) * 24.999999) + tableIndex = idxY * 25 + idxX + res = (0.0, 0.0, 0.0, 0.0) + for ii in range(4): + for jj in range(4): + idx_y = np.clip(sy + ii - 1, 0, src_rows - 1) + idx_x = np.clip(sx + jj - 1, 0, src_cols - 1) + src_val = img_src[0, :, idx_y, idx_x] + coeffs_offs = tableIndex * 16 + (ii * 4 + jj) + coeffs = img_coeffs[coeffs_offs] + res = res + src_val * coeffs + + img_dst[0, :, i, j] = np.clip(res, 0.0, 1.0) + + +@pytest.mark.parametrize("shapes", [[360, 640, 140, 280],]) +def test_lanzcos(shapes): + c_lanczosCoeffs = torch.randn(10000, dtype=torch.float32, device="npu") / 4.0 + src_rows, src_cols, dst_rows, dst_cols = shapes + img_src = torch.randn(1, 4, src_rows, src_cols, dtype=torch.float32, device="npu") + img_dst = torch.zeros( + (1, img_src.shape[1], dst_rows, dst_cols), + dtype=img_src.dtype, + device=img_src.device, + ) + resized_image = lanczos_resize_triton( + img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols + ) + img_src_cpu = img_src.cpu().numpy() + img_dst_cpu = torch.zeros( + (1, img_src_cpu.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device="cpu" + ).numpy() + lanczos_resize_cpu( + img_src_cpu, img_dst_cpu, c_lanczosCoeffs.cpu().numpy(), dst_rows, dst_cols + ) + torch.testing.assert_close( + resized_image.cpu(), torch.from_numpy(img_dst_cpu), atol=1.0 / 255, rtol=0 + ) + + +def benchmark_test( + fn_ref, fn_triton, ref_args=(), triton_args=(), name="gen_fn", times=10, repeat=10 +): + import time + + print( + f"--------------------benchmark_{name} for {times * repeat} times--------------------" + ) + stream = torch.npu.current_stream() + # warm_up + stream.synchronize() + for _ in range(10): + fn_triton(*triton_args) + stream.synchronize() + + start = time.perf_counter() + for _ in range(times * repeat): + fn_triton(*triton_args) + stream.synchronize() + end = time.perf_counter() + + time_compiled = (end - start) / (times * repeat) + time_compiled *= 1000000 + + # warm_up + stream.synchronize() + for _ in range(10): + std = fn_ref(*ref_args) + stream.synchronize() + + start = time.perf_counter() + for _ in range(times * repeat): + std = fn_ref(*ref_args) + stream.synchronize() + end = time.perf_counter() + time_eager = (end - start) / (times * repeat) + time_eager *= 1000000 + + accelerated = (time_eager - time_compiled) / time_compiled * 100 + print( + f"Accelerated: {accelerated:.4f}% eager takes {time_eager:.3f} us, triton takes {time_compiled:.3f} us" + ) + + return accelerated, time_eager, time_compiled + + +if __name__ == "__main__": + c_lanczosCoeffs = torch.randn(10000, dtype=torch.float32, device="npu") / 4.0 + + src_rows, src_cols = 360, 640 + 
dst_rows, dst_cols = 140, 280 + img_src = torch.randn(1, 4, src_rows, src_cols, dtype=torch.float32, device="npu") + + print("==========run npu===============") + img_dst = torch.zeros( + (1, img_src.shape[1], dst_rows, dst_cols), + dtype=img_src.dtype, + device=img_src.device, + ) + resized_image = lanczos_resize_triton( + img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols + ) + resized_cpu = resized_image.cpu().numpy() + print("==========run cpu===============") + img_src_cpu = img_src.cpu().numpy() + img_dst_cpu = torch.zeros( + (1, img_src_cpu.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device="cpu" + ).numpy() + lanczos_resize_cpu( + img_src_cpu, img_dst_cpu, c_lanczosCoeffs.cpu().numpy(), dst_rows, dst_cols + ) + + print("==========compare result===============") + diff = np.abs(resized_cpu - img_dst_cpu) + max_diff_value = np.max(diff) + print("max diff float = ", max_diff_value) + print("max diff * 255 int = ", int(max_diff_value * 255)) + torch.testing.assert_close( + resized_image.cpu(), torch.from_numpy(img_dst_cpu), atol=1.0 / 255, rtol=0 + ) + + print("==========profiling===============") + accelerate, eager_time, triton_time = benchmark_test( + lanczos_resize_cpu, + lanczos_resize_triton, + ref_args=( + img_src_cpu, + img_dst_cpu, + c_lanczosCoeffs.cpu().numpy(), + dst_rows, + dst_cols, + ), + triton_args=(img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols), + name="lanzcos", + ) diff --git a/third_party/ascend/examples/pytest_ut/test_launcher_empty_signature.py b/third_party/ascend/examples/pytest_ut/test_launcher_empty_signature.py new file mode 100644 index 000000000..d3a1b61a9 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_launcher_empty_signature.py @@ -0,0 +1,17 @@ +import os +import pytest + +import triton +import triton.language as tl + + +@triton.jit +def _empty_kernel(): + return + + +@pytest.mark.interpreter +def test_launcher_empty_signature(): + grid = (1,) + _empty_kernel[grid]() + assert True diff --git a/third_party/ascend/examples/pytest_ut/test_layernorm.py b/third_party/ascend/examples/pytest_ut/test_layernorm.py new file mode 100644 index 000000000..d536a3cc3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_layernorm.py @@ -0,0 +1,105 @@ +import pytest +import torch +import triton +import triton.language as tl +import torch_npu + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) 
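+        # accumulate masked squared deviations; dividing by N after the loop gives the variance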
+ _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@torch.inference_mode() +def layer_norm(x, normalized_shape, weight, bias, eps=1e-5): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + kernel = _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + # print(kernel.asm['ttir']) + return y + + +def _layer_norm(M, N, dtype, eps=1e-5, device='npu'): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + print(f"layernorm {M},{N} {dtype} passed") + + +def test_layernorm(): + _layer_norm(128, 128, torch.float16) + _layer_norm(128, 128, torch.bfloat16) + _layer_norm(128, 128, torch.float32) + + # _layer_norm(128, 3, torch.bfloat16) + # _layer_norm(128, 16, torch.bfloat16) + # _layer_norm(128, 37, torch.bfloat16) + # _layer_norm(128, 781, torch.bfloat16) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_ldst.py b/third_party/ascend/examples/pytest_ut/test_ldst.py new file mode 100644 index 000000000..ed116a914 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_ldst.py @@ -0,0 +1,410 @@ +import torch, torch_npu +import triton +import triton.language as tl +import triton.language.math as tl_math +import pytest + +def test_ldst_indirect_00(): + + @triton.jit + def triton_ldst_indirect_00_kernel( + out_ptr0, in_ptr0, in_ptr1, OFFSET0: tl.constexpr, + XS: tl.constexpr + ): + pid = tl.program_id(0) + offset1 = tl.load(in_ptr0 + OFFSET0) + idx_in1 = offset1 + pid * XS + tl.arange(0, XS) + tmp0 = tl.load(in_ptr1 + idx_in1) + tmp1 = tl_math.exp(tmp0) + idx_out0 = pid * XS + tl.arange(0, XS) + tl.store(out_ptr0 + idx_out0, tmp1) + + def triton_ldst_indirect_00_func(x0, x1, s, xs): + n = x1.numel() + ns = n - s + assert ns == xs, "test only single core" + y0 = torch.empty((ns,), dtype=x1.dtype, device=x1.device) + triton_ldst_indirect_00_kernel[ns // xs, 1, 1]( + y0, x0, x1, OFFSET0 = s, XS = xs) + return y0 + + def torch_ldst_indirect_00_func(x0, x1, s): + offset = x0[s] + return 
torch.exp(x1[offset:]) + + DEV = "npu" + DTYPE = torch.float32 + offset = 0 + N0, N1 = 16, 16 + blocksize = 16 + assert N0 > offset, "offset must be < N0" + N1 = N1 + offset + x0 = torch.arange(0, N0, dtype=torch.int32, device=DEV) + x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_00_func(x0, x1, offset) + triton_cal = triton_ldst_indirect_00_func(x0, x1, offset, blocksize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_01(): + + @triton.jit + def triton_ldst_indirect_01_kernel( + out_ptr0, in_ptr0, in_ptr1, OFFSET0: tl.constexpr, + XS: tl.constexpr + ): + pid = tl.program_id(0) + offset1 = tl.load(in_ptr0 + OFFSET0) + idx_in1 = offset1 + pid * XS + tl.arange(0, XS) + tmp0 = tl.load(in_ptr1 + idx_in1) + tmp1 = tl_math.exp(tmp0) + idx_out0 = pid * XS + tl.arange(0, XS) + tl.store(out_ptr0 + idx_out0, tmp1) + + def triton_ldst_indirect_01_func(x0, x1, s, xs): + n = x1.numel() + ns = n - s + assert ns == xs, "test only single core" + y0 = torch.empty((ns,), dtype=x1.dtype, device=x1.device) + triton_ldst_indirect_01_kernel[ns // xs, 1, 1]( + y0, x0, x1, OFFSET0 = s, XS = xs) + return y0 + + def torch_ldst_indirect_01_func(x0, x1, s): + offset = x0[s] + return torch.exp(x1[offset:]) + + DEV = "npu" + DTYPE = torch.float32 + offset = 0 + N0, N1 = 16, 16 + blocksize = 16 + assert N0 > offset, "offset must be < N0" + N1 = N1 + offset + x0 = torch.arange(0, N0, device=DEV) # int64 + x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_01_func(x0, x1, offset) + triton_cal = triton_ldst_indirect_01_func(x0, x1, offset, blocksize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_02(): + + @triton.jit + def triton_ldst_indirect_02_kernel( + out_ptr0, in_ptr0, in_ptr1, + XS: tl.constexpr + ): + pid = tl.program_id(0) + for i in tl.range(0, XS): + tmp0 = tl.load(in_ptr0 + i) + tmp1 = tl.load(in_ptr1 + tmp0) + tmp2 = tl_math.exp(tmp1) + tl.store(out_ptr0 + i, tmp2) + + def triton_ldst_indirect_02_func(x0, x1, xs): + n0 = x0.numel() + assert n0 == xs, "test only single core" + y0 = torch.empty((n0,), dtype=x1.dtype, device=x1.device) + triton_ldst_indirect_02_kernel[n0 // xs, 1, 1]( + y0, x0, x1, XS = xs) + return y0 + + def torch_ldst_indirect_02_func(x0, x1): + return torch.exp(x1[x0]) + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 16 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == blocksize, "N0 must be == blocksize" + x0 = offset + torch.arange(0, N0, device=DEV) # int64 + x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_02_func(x0, x1) + triton_cal = triton_ldst_indirect_02_func(x0, x1, blocksize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_03(): + + @triton.jit + def triton_ldst_indirect_03_kernel( + out_ptr0, in_ptr0, in_ptr1, + XS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = pid * XS + tl.arange(0, XS) + tmp0 = tl.load(in_ptr0 + in_idx0) + tmp1 = tl.load(in_ptr1 + tmp0) + tmp2 = tl_math.exp(tmp1) + out0_idx = pid * XS + tl.arange(0, XS) + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_03_func(x0, x1, xs): + n0 = x0.numel() + assert n0 == xs, "test only single core" + y0 = torch.empty((n0,), dtype=x1.dtype, device=x1.device) + triton_ldst_indirect_03_kernel[n0 // xs, 1, 1]( + y0, x0, x1, XS = xs) + return y0 + + def torch_ldst_indirect_03_func(x0, x1): + return torch.exp(x1[x0]) + + DEV = "npu" + DTYPE = torch.float32 + 
offset = 8 + N0, N1 = 16, 32 + blocksize = 16 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == blocksize, "N0 must be == blocksize" + x0 = offset + torch.arange(0, N0, device=DEV) # int64 + x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_03_func(x0, x1) + triton_cal = triton_ldst_indirect_03_func(x0, x1, blocksize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_04(): + + @triton.jit + def triton_ldst_indirect_04_kernel( + out_ptr0, in_ptr0, in_ptr1, + XS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = pid * XS + tl.arange(0, XS) + tmp0 = tl.load(in_ptr0 + in_idx0) + tmp0min = tl.min(tmp0, axis=0) + tmp0max = tl.max(tmp0, axis=0) + tmp0 = tmp0 * 2.0 + tmp0 = tl.clamp(tmp0, tmp0min, tmp0max) + tmp0 = tmp0.to(tl.int32) + tmp1 = tl.load(in_ptr1 + tmp0) + tmp2 = tl_math.exp(tmp1) + out0_idx = pid * XS + tl.arange(0, XS) + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_04_func(x0, x1, xs): + n0 = x0.numel() + assert n0 == xs, "test only single core" + y0 = torch.empty((n0,), dtype=x1.dtype, device=x1.device) + triton_ldst_indirect_04_kernel[n0 // xs, 1, 1]( + y0, x0, x1, XS = xs) + return y0 + + def torch_ldst_indirect_04_func(x0, x1): + x0min = torch.min(x0) + x0max = torch.max(x0) + idx = torch.clamp(x0*2, x0min, x0max) + return torch.exp(x1[idx.to(torch.int32)]) + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 16 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == blocksize, "N0 must be == blocksize" + x0 = offset + torch.arange(0, N0, dtype=torch.float32, device=DEV) + x1 = torch.randn((N1,), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_04_func(x0, x1) + triton_cal = triton_ldst_indirect_04_func(x0, x1, blocksize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_05(): + + @triton.jit + def triton_ldst_indirect_05_kernel( + out_ptr0, in_ptr1, in_ptr2, stride_in_r, + XS: tl.constexpr, RS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = pid * XS + tl.arange(0, XS) + in_idx1 = tl.arange(0, RS) + tmp0 = tl.arange(0, XS) + tmp1 = tl.load(in_ptr1 + in_idx1) + in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :] + tmp2 = tl.load(in_ptr2 + in_idx2) + tmp2 = tl_math.exp(tmp2) + out0_idx = in_idx0[:, None] * RS + in_idx1[None, :] + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_05_func(xc, x2, xs, rs): + nr = x2.size()[0] + nc = xc.numel() + stride_in_r = x2.stride()[0] + assert nr == xs, "test only single core" + y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device) + triton_ldst_indirect_05_kernel[nr // xs, 1, 1]( + y0, xc, x2, stride_in_r, XS = xs, RS = rs) + return y0 + + def torch_ldst_indirect_05_func(xr, xc, x2): + flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten() + extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()]) + return torch.exp(extracted) + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 8 + lowdimsize = N0 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == lowdimsize, "N0 must be == lowdimsize" + xc = offset + torch.arange(0, N0, device=DEV) + xr = torch.arange(0, blocksize, device=DEV) + x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_05_func(xr, xc, x2) + triton_cal = triton_ldst_indirect_05_func(xc, x2, blocksize, lowdimsize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_06(): + + 
@triton.jit + def triton_ldst_indirect_06_kernel( + out_ptr0, in_ptr0, in_ptr1, in_ptr2, stride_in_r, + XS: tl.constexpr, RS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = pid * XS + tl.arange(0, XS) + in_idx1 = tl.arange(0, RS) + tmp0 = tl.load(in_ptr0 + in_idx0) + tmp1 = tl.load(in_ptr1 + in_idx1) + in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :] + tmp2 = tl.load(in_ptr2 + in_idx2) + tmp2 = tl_math.exp(tmp2) + out0_idx = in_idx0[:, None] * RS + in_idx1[None, :] + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_06_func(xr, xc, x2, xs, rs): + nr = x2.size()[0] + nc = xc.numel() + stride_in_r = x2.stride()[0] + assert nr == xs, "test only single core" + y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device) + triton_ldst_indirect_06_kernel[nr // xs, 1, 1]( + y0, xr, xc, x2, stride_in_r, XS = xs, RS = rs) + return y0 + + def torch_ldst_indirect_06_func(xr, xc, x2): + flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten() + extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()]) + return torch.exp(extracted) + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 4 + lowdimsize = N0 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == lowdimsize, "N0 must be == lowdimsize" + xc = offset + torch.arange(0, N0, device=DEV) + xr = torch.arange(0, blocksize, device=DEV) + x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_06_func(xr, xc, x2) + triton_cal = triton_ldst_indirect_06_func(xr, xc, x2, blocksize, lowdimsize) + torch.testing.assert_close(triton_cal, torch_ref) + +def test_ldst_indirect_07(): + + @triton.jit + def triton_ldst_indirect_07_kernel( + out_ptr0, in_ptr0, in_ptr1, in_ptr2, stride_in_r, + XS: tl.constexpr, RS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = pid * XS + tl.arange(0, XS) + in_idx1 = tl.arange(0, RS) + tmp0 = tl.load(in_ptr0 + in_idx0) + tmp1 = tl.load(in_ptr1 + in_idx1) + in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :] + tmp2 = tl.load(in_ptr2 + in_idx2) + out0_idx = in_idx0[:, None] * RS + in_idx1[None, :] + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_07_func(xr, xc, x2, xs, rs): + nr = x2.size()[0] + nc = xc.numel() + stride_in_r = x2.stride()[0] + assert nr == xs, "test only single core" + y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device) + triton_ldst_indirect_07_kernel[nr // xs, 1, 1]( + y0, xr, xc, x2, stride_in_r, XS = xs, RS = rs) + return y0 + + def torch_ldst_indirect_07_func(xr, xc, x2): + flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten() + extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()]) + return extracted + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 4 + lowdimsize = N0 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == lowdimsize, "N0 must be == lowdimsize" + xc = offset + torch.arange(0, N0, device=DEV) + xr = torch.arange(0, blocksize, device=DEV) + x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_07_func(xr, xc, x2) + triton_cal = triton_ldst_indirect_07_func(xr, xc, x2, blocksize, lowdimsize) + torch.testing.assert_close(triton_cal, torch_ref) + +@pytest.mark.skip(reason="Indirect store to be supported") +def test_ldst_indirect_08(): + + @triton.jit + def triton_ldst_indirect_08_kernel( + out_ptr0, in_ptr1, in_ptr2, in_ptr3, stride_in_r, + XS: tl.constexpr, RS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = 
pid * XS + tl.arange(0, XS) + in_idx1 = tl.arange(0, RS) + tmp0 = tl.arange(0, XS) + tmp1 = tl.load(in_ptr1 + in_idx1) + in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :] + tmp2 = tl.load(in_ptr2 + in_idx2) + tmp2 = tl_math.exp(tmp2) + tmp3 = tl.load(in_ptr3 + in_idx1) + tmp3 = tmp3 + 1 + out0_idx = in_idx0[:, None] * RS + tmp3[None, :] + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_08_func(xc, x2, xs, rs): + nr = x2.size()[0] + nc = xc.numel() + stride_in_r = x2.stride()[0] + assert nr == xs, "test only single core" + y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device) + xc1 = xc - 1 + triton_ldst_indirect_08_kernel[nr // xs, 1, 1]( + y0, xc, x2, xc1, stride_in_r, XS = xs, RS = rs) + return y0 + + def torch_ldst_indirect_08_func(xr, xc, x2): + flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten() + extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()]) + return torch.exp(extracted) + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 8 + lowdimsize = N0 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == lowdimsize, "N0 must be == lowdimsize" + xc = offset + torch.arange(0, N0, device=DEV) + xr = torch.arange(0, blocksize, device=DEV) + x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_08_func(xr, xc, x2) + triton_cal = triton_ldst_indirect_08_func(xc, x2, blocksize, lowdimsize) + torch.testing.assert_close(triton_cal, torch_ref) + +if __name__ == "__main__": + test_ldst_indirect_05() + print("success: test_ldst_indirect_05") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_le.py b/third_party/ascend/examples/pytest_ut/test_le.py new file mode 100644 index 000000000..5add0627d --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_le.py @@ -0,0 +1,55 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_binary(x0, y0): + res = x0 <= y0 + return res + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x <= y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + (torch.int8, 'int8'), + (torch.int16, 'int16'), + (torch.int32, 'int32'), + (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + y0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + ans = standard_binary(x0, y0) + out = torch.zeros((N,), dtype=torch.bool).npu() + triton_elementwise_binary[1, 1, 1](x0, y0, out, N, NUMEL) + test_common.validate_cmp(sigtype, out, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_load.py b/third_party/ascend/examples/pytest_ut/test_load.py new file mode 100644 index 000000000..55bec58c6 --- /dev/null +++ 
b/third_party/ascend/examples/pytest_ut/test_load.py @@ -0,0 +1,102 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +# eg: pytest -v test.py::test_add +############################# + +@triton.jit +def triton_load_store(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp2 = tmp0 + tl.store(out_ptr0 + (xindex), tmp2, xmask) + + +# require: all data (4d and 5d) can be placed into but without ub overflow +@triton.jit +def triton_load_store_multi_d( + in_ptr0, out_ptr0, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + tmp_in = tl.load(in_ptr0 + offsets, masks) + tmp_out = tmp_in + tl.store(out_ptr0 + offsets, tmp_out, masks) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ['float32', (8, 8, 4), 2, 128, 64], + ['float16', (8, 8, 4), 2, 128, 64], + ['int8', (8, 8, 4), 2, 128, 64], + ['int8', (8, 7, 4), 2, 128, 64], + + ] + ) +def test_load_store(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_load_store[(ncore, )](x0, y_cal, x0.numel(), xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (8, 4, 16, 16)], + ['float16', (8, 4, 16, 16)], + ['int8', (8, 4, 16, 16)], + ['float32', (8, 8, 4, 4)], + ['float16', (8, 8, 4, 4)], + ['int8', (8, 8, 4, 4)], + ['float32', (3, 8, 2, 16, 16)], + ['float16', (3, 8, 2, 16, 16)], + ['int8', (9, 8, 8, 16, 16)], + ['float32', (11, 8, 8, 4, 4)], + ['float16', (11, 8, 8, 4, 4)], + ['int8', (11, 8, 8, 4, 4)], + ] + ) +def test_load_store_multi_d(param_list): + dtype, shape = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_expect = x0 + y_actual = test_common.generate_tensor(shape, dtype).npu() + + blocks = list(x0.size()) + 
shapes = list(x0.stride()) + while len(blocks) < 5: + blocks.append(1) + shapes.append(1) + triton_load_store_multi_d[(1, )](x0, y_actual, *blocks, *blocks, *shapes) + test_common.validate_cmp(dtype, y_actual, y_expect) diff --git a/third_party/ascend/examples/pytest_ut/test_load_store.py b/third_party/ascend/examples/pytest_ut/test_load_store.py new file mode 100644 index 000000000..23761ca2e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_load_store.py @@ -0,0 +1,138 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +@triton.jit +def triton_load_store(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = x_index < xnumel + tmp0 = tl.load(in_ptr0 + x_index, xmask) + tmp2 = tmp0 + tl.store(out_ptr0 + x_index, tmp2, xmask) + + +# require: all data (4d and 5d) can be placed into but without ub overflow +@triton.jit +def triton_load_store_multi_d( + in_ptr0, out_ptr0, + BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, BLOCK_2: tl.constexpr, BLOCK_3: tl.constexpr, BLOCK_4: tl.constexpr, + SHAPE_0: tl.constexpr, SHAPE_1: tl.constexpr, SHAPE_2: tl.constexpr, SHAPE_3: tl.constexpr, SHAPE_4: tl.constexpr, + STRIDE_0: tl.constexpr, STRIDE_1: tl.constexpr, STRIDE_2: tl.constexpr, STRIDE_3: tl.constexpr, STRIDE_4: tl.constexpr +): + offsets = tl.program_id(0) + + offsets = offsets + tl.arange(0, BLOCK_0) * STRIDE_0 + masks = tl.arange(0, BLOCK_0) < SHAPE_0 + if (BLOCK_1 * BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, None] + tl.arange(0, BLOCK_1)[None, :] * STRIDE_1 + masks = masks[:, None] & (tl.arange(0, BLOCK_1)[None, :] < SHAPE_1) + if (BLOCK_2 * BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, None] + tl.arange(0, BLOCK_2)[None, None, :] * STRIDE_2 + masks = masks[:, :, None] & (tl.arange(0, BLOCK_2)[None, None, :] < SHAPE_2) + if (BLOCK_3 * BLOCK_4) > 1: + offsets = offsets[:, :, :, None] + tl.arange(0, BLOCK_3)[None, None, None, :] * STRIDE_3 + masks = masks[:, :, :, None] & (tl.arange(0, BLOCK_3)[None, None, None, :] < SHAPE_3) + if BLOCK_4 > 1: + offsets = offsets[:, :, :, :, None] + tl.arange(0, BLOCK_4)[None, None, None, None, :] * STRIDE_4 + masks = masks[:, :, :, :, None] & (tl.arange(0, BLOCK_4)[None, None, None, None, :] < SHAPE_4) + + tmp_in = tl.load(in_ptr0 + offsets, masks) + tmp_out = tmp_in + tl.store(out_ptr0 + offsets, tmp_out, masks) + + +@triton.jit +def triton_load_store_sle_mask(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = x_index <= xnumel + tmp0 = tl.load(in_ptr0 + x_index, xmask) + tmp2 = tmp0 + tl.store(out_ptr0 + x_index, tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ['float32', (8, 8, 4), 2, 128, 64], + ['float16', (8, 8, 4), 2, 128, 64], + ['int8', (8, 8, 4), 2, 128, 64], + ] + ) +def test_load_store(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + y_ref = x0 + # triton结果 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_load_store[(ncore, )](x0, 
y_cal, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, y_cal, y_ref) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (8, 4, 16, 16)], + ['float16', (8, 4, 16, 16)], + ['int8', (8, 4, 16, 16)], + ['float32', (8, 8, 4, 4)], + ['float16', (8, 8, 4, 4)], + ['int8', (8, 8, 4, 4)], + ['float32', (3, 8, 2, 16, 16)], + ['float16', (3, 8, 2, 16, 16)], + ['int8', (9, 8, 8, 16, 16)], + ['float32', (11, 8, 8, 4, 4)], + ['float16', (11, 8, 8, 4, 4)], + ['int8', (11, 8, 8, 4, 4)], + ] + ) +def test_load_store_multi_d(param_list): + # 生成数据 + dtype, shape = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + y_expect = x0 + y_actual = test_common.generate_tensor(shape, dtype).npu() + # triton结果 + blocks = list(x0.size()) + shapes = list(x0.stride()) + while len(blocks) < 5: + blocks.append(1) + shapes.append(1) + triton_load_store_multi_d[(1, )](x0, y_actual, *blocks, *blocks, *shapes) + # 比较结果 + test_common.validate_cmp(dtype, y_actual, y_expect) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ['float32', (8, 8, 4), 2, 128, 64], + ['float16', (8, 8, 4), 2, 128, 64], + ['int8', (8, 8, 4), 2, 128, 64], + ] + ) +def test_load_store_sle_mask(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + y_ref = x0 + # triton结果 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_load_store_sle_mask[(ncore, )](x0, y_cal, x0.numel() - 1, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_log.py b/third_party/ascend/examples/pytest_ut/test_log.py new file mode 100644 index 000000000..b94dcfe7e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_log.py @@ -0,0 +1,78 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_unary(x0, dtype): + res = torch.log(x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = tl.math.log(x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = 
test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_log1p.py b/third_party/ascend/examples/pytest_ut/test_log1p.py new file mode 100644 index 000000000..4647f7a4c --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_log1p.py @@ -0,0 +1,41 @@ +import pytest +import triton +import triton.language as tl +import torch +import test_common +import triton.language.extra.ascend.libdevice as libdevice + + +def torch_log1p(x0, x1): + res = x0 + torch.log1p(x1) + return res + + +@triton.jit +def triton_log1p(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = x_index < xnumel + tmp0 = tl.load(in_ptr0 + x_index, xmask) + tmp1 = tl.load(in_ptr1 + x_index, xmask) + tmp2 = tmp0 + libdevice.log1p(tmp1) + tl.store(out_ptr0 + x_index, tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_log1p(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_log1p(x0, x1) + # triton结果 + triton_res = test_common.generate_tensor(shape, dtype).npu() + triton_log1p[ncore, 1, 1](x0, x1, triton_res, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_log2.py b/third_party/ascend/examples/pytest_ut/test_log2.py new file mode 100644 index 000000000..4371bd020 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_log2.py @@ -0,0 +1,40 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common + + +def torch_log2(x0): + res = torch.log2(x0) + return res + + +@triton.jit +def triton_log2(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_inedx = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_inedx, None) + tmp2 = tl.log2(tmp0) + tl.store(out_ptr0 + x_inedx, tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ] + ) +def test_log2(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_log2(x0) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_log2[ncore, 1, 1](x0, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_log_2.py b/third_party/ascend/examples/pytest_ut/test_log_2.py new file mode 100644 index 000000000..fc9ccc57b --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_log_2.py @@ -0,0 +1,40 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_exp2(x0): + res = torch.pow(2, x0, out=None) + return res + + +@triton.jit +def triton_exp2(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index, None) + tmp1 = tl.exp2(tmp0) + tl.store(out_ptr0 + x_index, tmp1, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_exp2(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_exp2(x0) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_exp2[ncore, 1, 1](x0, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_logical_and.py b/third_party/ascend/examples/pytest_ut/test_logical_and.py new file mode 100644 index 000000000..f968fd695 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_logical_and.py @@ -0,0 +1,42 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_logical_and(x0, x1): + res = torch.logical_and(x0, x1) + return res + + +@triton.jit +def triton_logical_and(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index) + tmp1 = tl.load(in_ptr1 + x_index) + tmp2 = tmp0.logical_and(tmp1) + tl.store(out_ptr0 + x_index, tmp2) + + +@pytest.mark.parametrize('param_list', + [ + ['bool', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_and(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch.logical_and(x0, x1) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_logical_and[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_logical_or.py b/third_party/ascend/examples/pytest_ut/test_logical_or.py new file mode 100644 index 000000000..c863786d8 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_logical_or.py @@ -0,0 +1,42 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_logical_or(x0, x1): + res = torch.logical_or(x0, x1) + return res + + +@triton.jit +def triton_logical_or(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index) + tmp1 = tl.load(in_ptr1 + x_index) + tmp2 = tmp0.logical_or(tmp1) + tl.store(out_ptr0 + x_index, tmp2) + + +@pytest.mark.parametrize('param_list', + [ + ['bool', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_logical_or(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_logical_or(x0, x1) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_logical_or[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_lshift.py b/third_party/ascend/examples/pytest_ut/test_lshift.py new file mode 100644 index 000000000..0fd23a69b --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_lshift.py @@ -0,0 +1,79 @@ +import pytest +import triton +import triton.language as tl +import time +import test_common +import torch +import torch_npu + + +def standard_unary(x0, dtype): + res = x0 << 2 + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + tmp = tl.cast(2, tl.int8) + ret = x << tmp + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + # (torch.float32, 'float32'), + # (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) 
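+    # the reference (x0 << 2 via standard_unary) is computed on the CPU tensor before x0 is moved to the NPU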
+ + ans = standard_unary(x0, dtype) + x0 = x0.npu() + # print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + # print(out) + + test_common.validate_cmp(sigtype, out, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_lt.py b/third_party/ascend/examples/pytest_ut/test_lt.py new file mode 100644 index 000000000..e220767d6 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_lt.py @@ -0,0 +1,41 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_lt(x0, x1): + return x0 < x1 + + +@triton.jit +def triton_lt(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index, None) + tmp1 = tl.load(in_ptr1 + x_index, None) + tmp2 = tmp0 < tmp1 + tl.store(out_ptr0 + x_index, tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (32,), 1, 32, 32], + ]) +def test_lt(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_lt(x0, x1).to(eval('torch.' + dtype)) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_lt[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_makeblockptr_permute.py b/third_party/ascend/examples/pytest_ut/test_makeblockptr_permute.py new file mode 100644 index 000000000..bfb791c9a --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_makeblockptr_permute.py @@ -0,0 +1,70 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common + +@pytest.mark.parametrize('shape', [(1, 4, 2)]) +@pytest.mark.parametrize('permute_order', [(2, 0, 1)]) +def test_makeblockptr_order(shape, permute_order): + + @triton.jit + def triton_kernel(in0_ptr: tl.tensor, # of tl.pointer_type + out0_ptr: tl.tensor, # of tl.pointer_type + in0_stride0: int, in0_stride1: int, in0_stride2: int, # strides for in0 + in0_stride_order0: tl.constexpr, in0_stride_order1: tl.constexpr, in0_stride_order2: tl.constexpr, # stride order for in0 + out0_stride0: int, out0_stride1: int, out0_stride2: int, # strides for out0 + out0_stride_order0: tl.constexpr, out0_stride_order1: tl.constexpr, out0_stride_order2: tl.constexpr, # stride order for out0 + s0: int, s1: int, s2: int, + tile_size0: tl.constexpr, tile_size1: tl.constexpr, tile_size2: tl.constexpr, + ): + tile_id0 = tl.program_id(axis=0) + tile_id1 = tl.program_id(axis=1) + tile_id2 = tl.program_id(axis=2) + offset0 = (tile_id0 * tile_size0).to(tl.int32) + offset1 = (tile_id1 * tile_size1).to(tl.int32) + offset2 = (tile_id2 * tile_size2).to(tl.int32) + in0_bptr = tl.make_block_ptr(in0_ptr, + (s0, s1, s2), + (in0_stride0, in0_stride1, in0_stride2), + (offset0, offset1, offset2), + (tile_size0, tile_size1, tile_size2), + order=(in0_stride_order0, in0_stride_order1, in0_stride_order2)) + in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0, in0_stride_order1, in0_stride_order2)).to(in0_ptr.type.element_ty) + + out0 = in0 + + out0_bptr = 
tl.make_block_ptr(out0_ptr, (s0, s1, s2), (out0_stride0, out0_stride1, out0_stride2), (offset0, offset1, offset2), (tile_size0, tile_size1, tile_size2), + order=(out0_stride_order0, out0_stride_order1, out0_stride_order2)) + tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0, out0_stride_order1, out0_stride_order2)) + + def triton_func(in0: torch.Tensor, permute_order): + # in fact, it adjusts the layout metadata instead of doing a real permutation. + in0_permuted_tmp = in0.permute(permute_order) + in0_permuted_shape = in0_permuted_tmp.size() + in0_permuted_strides = in0_permuted_tmp.stride() + in0_stride_order = [len(permute_order)-1-i for i in permute_order] + shape = (in0_permuted_shape[0], in0_permuted_shape[1], in0_permuted_shape[2]) + tile_sizes = (shape[0], shape[1], shape[2]) + out0 = torch.empty(shape, dtype=in0.dtype, device=in0.device) + out0_strides = out0.stride() + out0_stride_order = [len(permute_order)-1-i for i in range(len(permute_order))] + grid = (shape[0]//tile_sizes[0], shape[1]//tile_sizes[1], shape[2]//tile_sizes[2]) + triton_kernel[grid]( + in0, out0, + in0_permuted_strides[0], in0_permuted_strides[1], in0_permuted_strides[2], # stride for in0 + in0_stride_order[0], in0_stride_order[1], in0_stride_order[2], # stride order for in0 + out0_strides[0], out0_strides[1], out0_strides[2], # stride for out0 + out0_stride_order[0], out0_stride_order[1], out0_stride_order[2], # stride orderfor out0 + shape[0], shape[1], shape[2], # task indexing space + tile_size0=tile_sizes[0], + tile_size1=tile_sizes[1], + tile_size2=tile_sizes[2], + ) + return out0 + + x0 = torch.randint(0, 9, shape, dtype=torch.int32).npu() + torch_ref = torch.permute(x0, permute_order) + triton_cal = triton_func(x0, permute_order) + test_common.validate_cmp("int32", triton_cal, torch_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_max_dim0.py b/third_party/ascend/examples/pytest_ut/test_max_dim0.py new file mode 100644 index 000000000..5491ef5eb --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_max_dim0.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + +def standard_max(x0,dim,dtype): + (res, maxindex) = torch.max(x0, dim) + return res + +@triton.jit +def triton_max_dim0(in_ptr0, out_ptr0, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + + mmask = mblk_idx 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. 
[Online Softmax] Second pass: compute gradients + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + tl.debug_barrier() + + # 5. 
Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) + + +MAX_FUSED_SIZE = 65536 // 2 + + +def cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + assert (target * target_mask).max() < _input.shape[-1], ( + f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}" + ) + assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. 
Expected >= 0" + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + ) + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() + weight_sum = weight.sum().item() + # ensure weight is contiguous + if weight.stride(-1) != 1: + weight = weight.contiguous() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + weight_ptr=weight, # dummy if None + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + sum_non_ignore_weight=sum_non_ignore_weight, + ignore_index=ignore_index, + weight_sum=weight_sum, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + + return loss, z_loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + # If reduction is 'none' + elif grad_output.ndim > 0: + _input = _input * grad_output.unsqueeze(dim=1) + # If reduction is ['mean', 'sum'], grad_output is just a scalar + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + else: + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return _input + + +class LigerCrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Liger Cross Entropy loss. + It overrides the forward and backward methods of the torch.autograd.Function class. + """ + + @staticmethod + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.FloatTensor], + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, + ): + """ + The forward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. + target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + weight(Tensor, optional): a manual rescaling weight given to each class. 
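For orientation, a minimal usage sketch of the liger_cross_entropy wrapper defined below; the device string, shapes and hyperparameters are illustrative, and an Ascend NPU with torch_npu plus the definitions in this file are assumed to be available:

import torch
import torch_npu  # noqa: F401  (assumed available, as in the tests in this patch)

logits = torch.randn(8, 32000, device="npu", requires_grad=True)  # (BT, V), illustrative sizes
labels = torch.randint(0, 32000, (8,), device="npu")
loss = liger_cross_entropy(logits, labels, ignore_index=-100, label_smoothing=0.1)
loss.backward()           # the fused kernel already wrote dloss/dlogits into the logits buffer
print(logits.grad.shape)  # torch.Size([8, 32000])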
If given, has to be a Tensor of size V and floating point dtype + ignore_index (int): The index to ignore in the target. + lse_square_scale (float): The scale factor for (logsumexp(_input)) ^ 2, added to the loss for training stability. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The reduction to apply to the output: "none" | "mean" | "sum". + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). + return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False` + + Returns: + tuple: A tuple (loss, z_loss) with the computed losses. The elements are tensors or None. + """ + loss, z_loss, _input = cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + ) + # If we don't detach the _input tensor, the memory will double + # Not sure why, but it seems the grad and the value briefly coexist in different locations + ctx.save_for_backward(_input.detach()) + ctx.return_z_loss = return_z_loss + + return loss, z_loss + + @staticmethod + def backward(ctx, grad_output, grad_output2): + """ + The backward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + grad_output2 (tensor): Not used. + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + + (_input,) = ctx.saved_tensors + _input = cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def liger_cross_entropy( + input, + target, + weight=None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, + lse_square_scale: float = 0.0, + softcap: Optional[float] = None, + return_z_loss: bool = False, +): + loss, z_loss = LigerCrossEntropyFunction.apply( + input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + ) + if not return_z_loss: + return loss + return loss, z_loss + + +def _test_correctness_functional( + B, + T, + V, + scalar, + dtype, + atol, + rtol, +): + _input = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + + x1 = _input.clone().requires_grad_(True) + x2 = _input.clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + y1, y1_z = liger_cross_entropy( + x1, + target, + None, + ignore_index=0, + lse_square_scale=1e-4, + label_smoothing=0.1, + reduction="mean", + softcap=30.0, + return_z_loss=True, + ) + y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True) + + assert torch.allclose(y1, y2, atol=atol, rtol=rtol) + assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) + + grad = torch.randn_like(y2) + + y1.backward(grad) + y2.backward(grad) + + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 512, 4096), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 1e-8, 5e-2), + (1.0, torch.float32, 1e-8, 1e-6), + (1.0, torch.int32, 0, 0), + ], +) +def 
test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): + _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol) diff --git a/third_party/ascend/examples/pytest_ut/test_nearest.py b/third_party/ascend/examples/pytest_ut/test_nearest.py new file mode 100644 index 000000000..de734ace8 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_nearest.py @@ -0,0 +1,105 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import math +import numpy as np +import pytest + +@triton.jit +def nearest_resize_kernel( + img_src_ptr, img_dst_ptr, src_rows, src_cols, dst_rows, dst_cols, + RR_H, RR_W, C, + stride_in_h, stride_in_w, stride_in_c, + stride_out_h, stride_out_w, stride_out_c, + BLOCK_SIZE: tl.constexpr +): + #RR_H和RR_W分别为高和宽的缩放比例 + block_id_c = tl.program_id(0) + block_id_h = tl.program_id(1) + block_id_w = tl.program_id(2) + dest_h_offs = ( + block_id_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + ) + dest_w_offs = ( + block_id_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + ) + dest_offs = ( + block_id_c[ None, None] * stride_out_c + + dest_h_offs[ :, None] * stride_out_h + + dest_w_offs[ None, :] * stride_out_w + ) + #根据output image的坐标值(dest_h_offs, dest_w_offs)计算input image的坐标值(sy, sx) + fy = dest_h_offs * RR_H + sy = tl.floor(fy) + fx = dest_w_offs * RR_W + sx = tl.floor(fx) + + src_offsets = ( + block_id_c[None, None] * stride_in_c + + tl.clamp(sy, 0, src_rows -1)[:, None].to(tl.int32) * stride_in_h + + tl.clamp(sx, 0, src_cols -1)[None, :].to(tl.int32) * stride_in_w) + src_val = tl.load(img_src_ptr + src_offsets) + dst_mask = (dest_h_offs[ :, None] < dst_rows) & (dest_w_offs[None, :] < dst_cols) + tl.store(img_dst_ptr + dest_offs, src_val, mask=dst_mask) + +def triton_kernel(img_src, img_dst): + N, C, src_rows, src_cols = img_src.shape + _, _, dst_rows, dst_cols = img_dst.shape + R_H = float(dst_rows) / src_rows + R_W = float(dst_cols) / src_cols + RR_H = 1.0 / R_H + RR_W = 1.0 / R_W + stride_in_n, stride_in_c, stride_in_h, stride_in_w = img_src.stride() + stride_out_n, stride_out_c, stride_out_h, stride_out_w = img_dst.stride() + bs = 16 + grid = lambda meta: ( + C, + triton.cdiv(dst_rows, meta["BLOCK_SIZE"]), + triton.cdiv(dst_cols, meta["BLOCK_SIZE"]), + ) + nearest_resize_kernel[grid]( + img_src, img_dst, src_rows, src_cols, dst_rows, dst_cols, + RR_H, RR_W, C, + stride_in_h, stride_in_w, stride_in_c, + stride_out_h, stride_out_w, stride_out_c, + bs) + return img_dst + +def nearest_resize_cpu(img_src, img_dst): + N, C, src_rows, src_cols = img_src.shape + _, _, dst_rows, dst_cols = img_dst.shape + #RR_H和RR_W分别为高和宽的缩放比例 + RR_H = src_rows / float(dst_rows) + RR_W = src_cols / float(dst_cols) + #根据output image的坐标值(i,j)计算input image的坐标值(sy, sx) + for i in range(dst_rows): + for j in range(dst_cols): + fy = (i * RR_H) + sy = math.floor(fy) + fx = (j * RR_W) + sx = math.floor(fx) + src_val = img_src[0, :, np.clip(sy, 0, src_rows -1), np.clip(sx, 0, src_cols -1)] + img_dst[0, :, i, j] = src_val + return img_dst + +@pytest.mark.parametrize("shapes", [[360, 640, 140, 280],]) +def test_nearest(shapes): + src_rows, src_cols, dst_rows, dst_cols = shapes + img_src = torch.rand(1, 4, src_rows, src_cols, dtype=torch.float32, device='npu') + img_dst = torch.zeros((1, img_src.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device=img_src.device) + torch_ref = nearest_resize_cpu(img_src.cpu(), img_dst.cpu()) + triton_cal = triton_kernel(img_src, img_dst) + torch.testing.assert_close(torch_ref.npu(), triton_cal) + +if __name__ == "__main__": + src_rows, 
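The resize kernel and nearest_resize_cpu above use the same source-index mapping: destination column j reads source column floor(j * src_cols / dst_cols), clamped to the valid range (and likewise for rows). A quick numeric check for the 640 -> 280 width used in the test; the printed values are exact for these shapes:

import math

RR_W = 640 / 280.0  # source columns per destination column
for j in (0, 1, 139, 279):
    sx = min(max(math.floor(j * RR_W), 0), 640 - 1)
    print(j, "->", sx)  # 0 -> 0, 1 -> 2, 139 -> 317, 279 -> 637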
src_cols = 360, 640 + dst_rows, dst_cols = 140, 280 + img_src = torch.rand(1, 4, src_rows, src_cols, dtype=torch.float32, device='npu') + img_dst = torch.zeros((1, img_src.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device=img_src.device) + + assert img_src.shape[0] == 1, "currently supports only shape[0] == 1 which does not change the functionality of thie case" + torch_ref = nearest_resize_cpu(img_src.cpu(), img_dst.cpu()) + triton_cal = triton_kernel(img_src, img_dst) + torch.testing.assert_close(torch_ref.npu(), triton_cal) + print("success") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_neg.py b/third_party/ascend/examples/pytest_ut/test_neg.py new file mode 100644 index 000000000..7b141191e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_neg.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest +import triton +import triton.language as tl +import time +import torch +import torch_npu +import test_common + +def torch_neg(x0): + res = -x0 + return res + +@triton.jit +def triton_neg(in_ptr0, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1 = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = -tmp0 + tl.store(out_ptr0 + (x0), tmp1, None) + +@pytest.mark.parametrize('param_list', + [ + ['float16', (8, 8), 8, 8, 8], + ['float32', (8, 8), 8, 8, 8], + ['int8', (2, 4096, 8), 32, 2048, 64], + ]) + +def test_neg(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_neg(x0) + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + triton_neg[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_not.py b/third_party/ascend/examples/pytest_ut/test_not.py new file mode 100644 index 000000000..ff710d6a7 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_not.py @@ -0,0 +1,44 @@ +import pytest +import triton +import triton.language as tl +import torch +import torch_npu +import test_common + + +def torch_not(x0): + res = torch.bitwise_not(x0) + return res + + +@triton.jit +def triton_not(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp2 = not(tmp0) + tl.store(out_ptr0 + (xindex), tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['int8', (2, 4096, 8), 2, 32768, 1024], + ['int16', (2, 4096, 8), 2, 32768, 1024], + ['int32', (2, 4096, 8), 2, 32768, 1024], + ['int64', (2, 4096, 8), 2, 32768, 1024], + ['bool', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_not(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_not(x0) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + triton_not[ncore, 1, 1](x0, triton_res, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_npu_indexing.py b/third_party/ascend/examples/pytest_ut/test_npu_indexing.py new file mode 100644 index 000000000..fd078dcda --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_npu_indexing.py @@ -0,0 +1,89 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import time + + +def foo(a, b, c): + Z, Y, X, R = (1, 1, 64, 64) + y = a + b + y = y.sum(-1) + y = y.unsqueeze(3) + y = y.broadcast_to(Z, Y, X, R) + b + y = c + y.permute(0, 1, 3, 2) + return y + + +@triton.jit +def triton_foo(in_ptr0, in_ptr1, in_ptr2, out_ptr0, BLOCK1: tl.constexpr, BLOCK1_SUB: tl.constexpr, + BLOCK2: tl.constexpr, + Z: tl.constexpr, Y: tl.constexpr, X: tl.constexpr, R: tl.constexpr, + Z_STRIDE: tl.constexpr, Y_STRIDE: tl.constexpr, X_STRIDE: tl.constexpr, R_STRIDE: tl.constexpr, + Z_STRIDE1: tl.constexpr, Y_STRIDE1: tl.constexpr, X_STRIDE1: tl.constexpr, R_STRIDE1: tl.constexpr, + ): + offset: tl.constexpr = tl.program_id(0) * BLOCK1 + base1 = tl.arange(0, BLOCK1_SUB) + base2 = tl.arange(0, BLOCK2) + nsub: tl.constexpr = BLOCK1 // BLOCK1_SUB + # loops1 : tl.constexpr = nsub * Y * Z + loops1: tl.constexpr = nsub + loops2: tl.constexpr = R // BLOCK2 + + for z in range(Z): + for y in range(Y): + for loop1 in range(loops1): + # y = (loop1 // nsub) % Y + # z = loop1 // nsub // Y + # off1 = (loop1 % nsub) + off1 = loop1 + x = offset + (off1 * BLOCK1_SUB) + base1[:, None] + x1 = offset + (off1 * BLOCK1_SUB) + base1[None, :] + _tmp4 = tl.full([BLOCK1_SUB, BLOCK2], 0, tl.float32) + for loop2 in range(loops2): + r = loop2 * BLOCK2 + base2[None, :] + tmp0 = tl.load(in_ptr0 + (R_STRIDE * r + (X_STRIDE * x) + (Y_STRIDE * y) + (Z_STRIDE * z)), None) + tmp1 = tl.load(in_ptr1 + (R_STRIDE * r + (X_STRIDE * x) + (Y_STRIDE * y) + (Z_STRIDE * z)), None) + tmp2 = tmp0 + tmp1 + _tmp4 = _tmp4 + tmp2 + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp5 = tmp4.reshape(BLOCK1_SUB, 1).broadcast_to(BLOCK1_SUB, BLOCK2) + + for loop2 in range(loops2): + r = loop2 * BLOCK2 + base2[None, :] + tmp6 = tl.load(in_ptr1 + (R_STRIDE * r + (X_STRIDE * x) + (Y_STRIDE * y) + (Z_STRIDE * z)), None) + tmp7 = tmp6 + tmp5 + r1 = loop2 * BLOCK2 + base2[:, None] + tmp8 = tl.load(in_ptr2 + (R_STRIDE1 * x1 + (X_STRIDE1 * r1) + (Y_STRIDE1 * y) + (Z_STRIDE1 * z)), + None) + + tmp9 = tmp8.reshape(BLOCK2, BLOCK1_SUB) + tmp7.reshape(BLOCK1_SUB, BLOCK2).permute(1, 0) + tl.store(out_ptr0 + (R_STRIDE1 * x1 + (X_STRIDE1 * r1) + (Y_STRIDE1 * y) + (Z_STRIDE1 * z)), tmp9, + None) + + +def foo_triton_wrapper(a, b, c): + NBLOCKS = 1 + BLOCK1 = a.shape[2] // NBLOCKS + BLOCK1_SUB = 64 + BLOCK2 = 64 + + value = torch.empty_strided((c.shape[0], c.shape[1], c.shape[2], c.shape[3]), + (c.stride()[0], c.stride()[1], c.stride()[2], c.stride()[3]), dtype=torch.float32).npu() + + triton_foo[NBLOCKS, 1, 1](a, b, c, value, BLOCK1, BLOCK1_SUB, BLOCK2, + a.shape[0], a.shape[1], a.shape[2], a.shape[3], + a.stride()[0], a.stride()[1], a.stride()[2], a.stride()[3], + c.stride()[0], c.stride()[1], c.stride()[2], c.stride()[3],) + return value + +def test_npu_indexing(): + Z, Y, X, R = (1, 1, 64, 64) + a = torch.randn((Z, Y, X, R), dtype=torch.float32).npu() + b = torch.randn((Z, Y, X, R), dtype=torch.float32).npu() + c = torch.randn((Z, Y, R, X), dtype=torch.float32).npu() + r = foo_triton_wrapper(a, b, c) + r1 = foo(a, b, c) + print(r[0, 0, 0:8, 0:8]) + 
print(r1[0, 0, 0:8, 0:8]) + torch.testing.assert_close(r, r1) diff --git a/third_party/ascend/examples/pytest_ut/test_npu_indexing2.py b/third_party/ascend/examples/pytest_ut/test_npu_indexing2.py new file mode 100644 index 000000000..90d10ad24 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_npu_indexing2.py @@ -0,0 +1,67 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import time + + +def foo(a, b, c): + y = a + b + c + y = y.sum(dim = 1) + return y + +@triton.jit +def triton_codegen2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr, RBLOCK : tl.constexpr): + ynumel = 8 + rnumel = 2048 + xnumel = 1024 + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + base2 = tl.arange(0, RBLOCK) + loops2: tl.constexpr = rnumel // RBLOCK + for y in range(ynumel): + y0 = y + for loop1 in range(loops1): + x = offset + (loop1 * XBLOCK_SUB) + base1 + x1 = offset + (loop1 * XBLOCK_SUB) + base1[None, :] + _tmp6 = tl.full([XBLOCK_SUB, RBLOCK], 0, tl.float32) + for loop2 in range(loops2): + r2 = loop2 * RBLOCK + base2[:, None] + tmp0 = tl.load(in_ptr0 + (x1 + (1024*r2) + (2097152*y0)), None, eviction_policy='evict_last') + tmp1 = tl.load(in_ptr1 + (x1 + (1024*r2) + (2097152*y0)), None, eviction_policy='evict_last') + tmp3 = tl.load(in_ptr2 + (x1 + (1024*r2) + (2097152*y0)), None, eviction_policy='evict_last') + tmp2 = tmp0 + tmp1 + tmp4 = tmp2 + tmp3 + tmp5 = tl.reshape(tmp4, [RBLOCK, XBLOCK_SUB]) + tmp7 = _tmp6 + tmp5 + _tmp6 = tmp7 + tmp6 = tl.sum(_tmp6, 0).reshape(XBLOCK_SUB) + + tl.store(out_ptr0 + (x + (1024*y0)), tmp6, None) + +def foo_triton_wrapper(a, b, c): + NBLOCKS = 8 + BLOCK1 = a.shape[2] // NBLOCKS + BLOCK1_SUB = 64 + BLOCK2 = 64 + + value = torch.empty_strided((c.shape[0], c.shape[2]), + ( c.shape[2], 1), dtype=torch.float32).npu() + + triton_codegen2[NBLOCKS, 1, 1](a, b, c, value, BLOCK1, BLOCK1_SUB, BLOCK2) + + return value + + +def test_npu_indexing2(): + + Y, X, R = ( 8, 2048, 1024) + a = torch.randn(( Y, X, R), dtype=torch.float32).npu() + b = torch.randn(( Y, X, R), dtype=torch.float32).npu() + c = torch.randn(( Y, X, R), dtype=torch.float32).npu() + r = foo_triton_wrapper(a, b, c) + r1 = foo(a, b, c) + print(r[ 0:8, 0:8, ]) + print(r1[ 0:8, 0:8]) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/third_party/ascend/examples/pytest_ut/test_or.py b/third_party/ascend/examples/pytest_ut/test_or.py new file mode 100644 index 000000000..94ce4ee11 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_or.py @@ -0,0 +1,41 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common + + +def torch_or(x0, x1): + res = x0 | x1 + return res + + +@triton.jit +def triton_or(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x_index = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + x_index, None) + tmp1 = tl.load(in_ptr1 + x_index, None) + tmp2 = tmp0 | tmp1 + tl.store(out_ptr0 + x_index, tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['int32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_or(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 
+ torch_res = torch_or(x0, x1) + # triton结果 + triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_or[ncore, 1, 1](x0, x1, triton_res, xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_permute.py b/third_party/ascend/examples/pytest_ut/test_permute.py new file mode 100644 index 000000000..c5f773219 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_permute.py @@ -0,0 +1,72 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import time + +@triton.jit +def triton_foo(in_ptr0, in_ptr1, in_ptr2, out_ptr0, BLOCK1: tl.constexpr, BLOCK1_SUB: tl.constexpr, + BLOCK2: tl.constexpr, + X: tl.constexpr, Y: tl.constexpr, Z: tl.constexpr, R: tl.constexpr, + Z_STRIDE: tl.constexpr, Y_STRIDE: tl.constexpr, X_STRIDE: tl.constexpr, R_STRIDE: tl.constexpr, + X_STRIDE1: tl.constexpr, Y_STRIDE1: tl.constexpr, Z_STRIDE1: tl.constexpr, R_STRIDE1: tl.constexpr, + ): + offset: tl.constexpr = tl.program_id(0) * BLOCK1 + base1 = tl.arange(0, BLOCK1_SUB) + base2 = tl.arange(0, BLOCK2) + nsub: tl.constexpr = BLOCK1 // BLOCK1_SUB + # loops1 : tl.constexpr = nsub * Y * Z + loops1: tl.constexpr = nsub + loops2: tl.constexpr = R // BLOCK2 + + for z in range(Z): + for y in range(Y): + for loop1 in range(loops1): + off1 = loop1 + x = offset + (off1 * BLOCK1_SUB) + base1[:, None] + x1 = offset + (off1 * BLOCK1_SUB) + base1[None, :] + + for loop2 in range(loops2): + r = loop2 * BLOCK2 + base2[None, :] + r1 = loop2 * BLOCK2 + base2[:, None] + tmp0 = tl.load(in_ptr0 + ( (R_STRIDE * r) + (X_STRIDE * x) + (Y_STRIDE * y) + (Z_STRIDE * z)), None) + tmp1 = tl.load(in_ptr1 + ( (R_STRIDE * r) + (X_STRIDE * x) + (Y_STRIDE * y) + (Z_STRIDE * z)), None) + tmp2 = tmp0 + tmp1 + + tmp8 = tl.load(in_ptr2 + (R_STRIDE1 * r + X_STRIDE1 * x + (Y_STRIDE1 * y) + (Z_STRIDE1 * z)), None) + tmp9 = tmp8 + tmp2 + tl.store(out_ptr0 + (R_STRIDE1 * r + X_STRIDE1 * x + (Y_STRIDE1 * y) + (Z_STRIDE1 * z)), tmp9, + None) + +def foo_triton_wrapper(a, b, c): + NBLOCKS = 32 if c.shape[0] >=256 else 1 + BLOCK1 = c.shape[0] // NBLOCKS + BLOCK1_SUB = BLOCK1 if BLOCK1 < 64 else 64 + BLOCK2 = c.shape[3] if c.shape[3] < 64 else 64 + + value = torch.empty_strided((c.shape[0], c.shape[1], c.shape[2], c.shape[3]), + (c.stride()[0], c.stride()[1], c.stride()[2], c.stride()[3]), dtype=torch.float32).npu() + + triton_foo[NBLOCKS, 1, 1](a, b, c, value, BLOCK1, BLOCK1_SUB, BLOCK2, + c.shape[0], c.shape[1], c.shape[2], c.shape[3], + a.stride()[0], a.stride()[1], a.stride()[2], a.stride()[3], + c.stride()[0], c.stride()[1], c.stride()[2], c.stride()[3],) + return value + +def foo(a, b, c): + y = a + b + y = c + y.permute(2, 1, 0, 3) + return y + + +def test_permute_handwritten(): + + Z, Y, X, R = (1, 12, 4096, 8) + a = torch.randn((Z, Y, X, R), dtype=torch.float32).npu() + b = torch.randn((Z, Y, X, R), dtype=torch.float32).npu() + c = torch.randn((X, Y, Z, R), dtype=torch.float32).npu() + r = foo_triton_wrapper(a, b, c) + r1 = foo(a, b, c) + print(r[0, 0, 0:8, 0:8]) + print(r1[0, 0, 0:8, 0:8]) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_permute_full.py b/third_party/ascend/examples/pytest_ut/test_permute_full.py new file mode 100644 index 000000000..f10a09f68 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_permute_full.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 
2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def fn_npu_021(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret=tl.permute(X,(0,2,1)) + + oidx=xidx[:,None,None]*YB*ZB+zidx[None,:,None]*YB+yidx[None,None,:] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_102(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret=tl.permute(X,(1,0,2)) + + oidx=yidx[:,None,None]*XB*ZB+xidx[None,:,None]*ZB+zidx[None,None,:] + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_210(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret=tl.permute(X,(2,1,0)) + + oidx=zidx[:,None,None]*YB*XB+yidx[None,:,None]*XB+xidx[None,None,:] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_201(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret=tl.permute(X,(2,0,1)) + + oidx=zidx[:,None,None]*YB*XB+xidx[None,:,None]*YB+yidx[None,None,:] + + tl.store(output_ptr + oidx, ret) + +@triton.jit +def fn_npu_120(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + # XB,YB,1 + X = tl.load(x_ptr + idx) + + ret=tl.permute(X,(1,2,0)) + + oidx=yidx[:,None,None]*ZB*XB+zidx[None,:,None]*XB+xidx[None,None,:] + + tl.store(output_ptr + oidx, ret) + + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', + [ + # ['float32',eval('torch.float32'),2,4,3], + ['float32',eval('torch.float32'),2,4,8], + # ['float32',eval('torch.float32'),2,4,37], + ['float32',eval('torch.float32'),2,4,64], + # ['float32',eval('torch.float32'),2,4,781], + + # ['float16',eval('torch.float16'),2,4,3], + ['float16',eval('torch.float16'),2,4,8], + # ['float16',eval('torch.float16'),2,4,37], + ['float16',eval('torch.float16'),2,4,64], + # ['float16',eval('torch.float16'),2,4,781], + + # ['int8',eval('torch.int8'),2,4,3], + ['int8',eval('torch.int8'),2,4,8], + # ['int8',eval('torch.int8'),2,4,37], + ['int8',eval('torch.int8'),2,4,64], + # ['int8',eval('torch.int8'),2,4,781], + ] + ) +def test_permute(para_type,data_type,XB,YB,ZB): + + x = torch.randint(low=0,high=2,size=(XB,YB,ZB),dtype=data_type).npu() + + output = torch.randint(1, (XB,ZB,YB), dtype=data_type).npu() + torch_021 = torch.permute(x,(0,2,1)) + fn_npu_021[1,1,1](output, x, XB, YB, ZB) + torch.testing.assert_close(output,torch_021) + + print(" test permute 021 passed") + + output = torch.randint(1, (YB,XB,ZB), dtype=data_type).npu() + torch_102 = torch.permute(x,(1,0,2)) + fn_npu_102[1,1,1](output, x, XB, YB, ZB) + torch.testing.assert_close(output,torch_102) + + print(" test 
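The oidx expressions in the kernels above are simply the row-major offsets of the permuted, contiguous output. For the (2, 1, 0) case, fn_npu_210 writes the element taken from (x, y, z) at linear offset z*YB*XB + y*XB + x. A small PyTorch check of that correspondence (sizes and indices are illustrative):

import torch

XB, YB, ZB = 2, 4, 8
x = torch.arange(XB * YB * ZB).reshape(XB, YB, ZB)
out = x.permute(2, 1, 0).contiguous()  # shape (ZB, YB, XB), strides (YB*XB, XB, 1)
z, y, xi = 3, 1, 1
assert out.reshape(-1)[z * YB * XB + y * XB + xi] == x[xi, y, z]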
permute 102 passed") + + output = torch.randint(1, (ZB,XB,YB), dtype=data_type).npu() + torch_201 = torch.permute(x, (2,0,1)) + fn_npu_201[1,1,1](output, x, XB, YB, ZB) + torch.testing.assert_close(output, torch_201) + + print(" test permute 201 passed") + + output = torch.randint(1, (ZB,YB,XB), dtype=data_type).npu() + torch_210 = torch.permute(x, (2,1,0)) + fn_npu_210[1,1,1](output, x, XB, YB, ZB) + torch.testing.assert_close(output, torch_210) + + print(" test permute 210 passed") + + output = torch.randint(1, (YB,ZB,XB), dtype=data_type).npu() + torch_120 = torch.permute(x, (1,2,0)) + fn_npu_120[1,1,1](output, x, XB, YB, ZB) + torch.testing.assert_close(output, torch_120) + + print(" test permute 120 passed") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_permute_reshape.py b/third_party/ascend/examples/pytest_ut/test_permute_reshape.py new file mode 100644 index 000000000..3f8e4ae41 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_permute_reshape.py @@ -0,0 +1,59 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import time + + +@triton.jit +def triton_foo(in_ptr0, in_ptr1, in_ptr2, out_ptr0, BLOCK1: tl.constexpr, BLOCK1_SUB: tl.constexpr, + BLOCK2: tl.constexpr, S: tl.constexpr, N: tl.constexpr, D: tl.constexpr, + ): + offset: tl.constexpr = tl.program_id(0) * BLOCK1 + base1 = tl.arange(0, BLOCK1_SUB) + base2 = tl.arange(0, BLOCK2) + loops1: tl.constexpr = BLOCK1 // BLOCK1_SUB + loops2: tl.constexpr = D // BLOCK2 + + for loop1 in range(loops1): + off1 = loop1 + s = offset + (off1 * BLOCK1_SUB) + base1[:, None] + for n in range(N) : + for loop2 in range(loops2): + d = loop2 * BLOCK2 + base2[None, :] + tmp0 = tl.load(in_ptr0 + ((32768*n) + (8*s) + d), None) + tmp1 = tl.load(in_ptr1 + ((32768*n) + (8*s) + d), None) + tmp2 = tmp0 + tmp1 + + tmp3 = tl.load(in_ptr2 + ((8*n) + d + (96*s)), None) + tmp9 = tmp3 + tmp2 + tl.store(out_ptr0 + ((8*n) + d + (96*s)), tmp9, None) + +def foo_triton_wrapper(a, b, c): + NBLOCKS = 32 if a.shape[2] >=256 else 1 + BLOCK1 = a.shape[2] // NBLOCKS + BLOCK1_SUB = BLOCK1 if BLOCK1 < 64 else 64 + BLOCK2 = a.shape[3] if a.shape[3] < 64 else 64 + + value = torch.empty_strided((c.shape[0], c.shape[1], c.shape[2]), + (c.stride()[0], c.stride()[1], c.stride()[2]), dtype=torch.float32).npu() + triton_foo[NBLOCKS, 1, 1](a, b, c, value, BLOCK1, BLOCK1_SUB, BLOCK2, a.shape[2], a.shape[1], a.shape[3]) + + return value + +def foo(a, b, c): + B, N, S, D = (1, 12, 4096, 8) + y = a + b + y = c + y.permute(2, 0, 1, 3 ).reshape(S,B, N*D) + return y + +def test_permute_reshape(): + B, N, S, D = (1, 12, 4096, 8) + a = torch.randn((B, N, S, D), dtype=torch.float32).npu() + b = torch.randn((B, N, S, D), dtype=torch.float32).npu() + c = torch.randn((S, B, N*D), dtype=torch.float32).npu() + r = foo_triton_wrapper(a, b, c) + r1 = foo(a, b, c) + print(r[0:8, 0, 0:8]) + print(r1[0:8,0, 0:8]) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_pointer_type.py b/third_party/ascend/examples/pytest_ut/test_pointer_type.py new file mode 100644 index 000000000..c40198772 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_pointer_type.py @@ -0,0 +1,27 @@ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def kernel(ans_ptr, x_ptr): + val = tl.load(x_ptr) + output_ptr = tl.load(ans_ptr) + output_ptr = 
output_ptr.to(tl.pointer_type(val.dtype)) + tl.store(output_ptr, val) + +@pytest.mark.parametrize("literal, dtype_str",[[0, eval('torch.int8')], [0, eval('torch.int16')], + [0, eval('torch.int32')], [0, eval('torch.int64')], + [0, eval('torch.float16')], [0, eval('torch.float32')]]) +def test_pointer_type(literal, dtype_str): + x = torch.randint(low=0, high=5, size=(1,), dtype=dtype_str).npu() + output = torch.zeros((1,), dtype=dtype_str).npu() + ans = [] + ans.append(output.data_ptr()) + ans_tensor = torch.tensor(ans).npu() + kernel[(1,)](ans_tensor, x) + assert torch.isclose(x, output) + print("Pointer type convert successful") \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_pow.py b/third_party/ascend/examples/pytest_ut/test_pow.py new file mode 100644 index 000000000..06ec9f70f --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_pow.py @@ -0,0 +1,111 @@ +import triton +import triton.language as tl +from triton.language.extra.ascend.libdevice import pow +import torch +import torch_npu +import pytest +import test_common + +types = [ + "float32", + "float16", + "bfloat16", + "int64", + "int32", + "int16", + "int8", +] + +shapes = [ + # 3, + # 32, + 37, + # 256, + # 781, +] + +@pytest.mark.skip(reason="waiting for bishengir-compile to support") +@pytest.mark.parametrize("sigtype", types) +@pytest.mark.parametrize("N", shapes) +def test_pow_vv(sigtype, N): + + def torch_func(x0, x1): + res = torch.pow(x0, x1) + return res + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, in_ptr1, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = pow(x0, x1) + tl.store(out_ptr0 + idx, ret) + + def triton_func(x0, x1, N): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + + triton_cal = triton_func(x0, x1, N) + torch_ref = torch_func(x0, x1) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.skip(reason="waiting for bishengir-compile to support") +@pytest.mark.parametrize("sigtype", types) +@pytest.mark.parametrize("N", shapes) +def test_pow_vs_dynamic(sigtype, N): + + def torch_func(x0, x1): + res = torch.pow(x0, x1) + return res + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, in_ptr1, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1) + ret = pow(x0, x1) + tl.store(out_ptr0 + idx, ret) + + def triton_func(x0, x1, N): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, N) + return out + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=(1,), dtype=sigtype).npu() + + triton_cal = triton_func(x0, x1, N) + torch_ref = torch_func(x0, x1) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + +@pytest.mark.skip(reason="waiting for bishengir-compile to support") +@pytest.mark.parametrize("sigtype", types) +@pytest.mark.parametrize("N", shapes) +def test_pow_vs_const(sigtype, N): + + def torch_func(x0, x1): + res = torch.pow(x0, x1) + return res + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, x1: tl.constexpr, N: tl.constexpr): + idx = tl.arange(0, N) + x0 = tl.load(in_ptr0 + idx) + ret = pow(x0, x1) + tl.store(out_ptr0 + idx, ret) + + def triton_func(x0, x1, N): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, N) + return out + + x0 = 
test_common.generate_tensor(shape=(N,), dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=(1,), dtype=sigtype).item() + + triton_cal = triton_func(x0, x1, N) + torch_ref = torch_func(x0, x1) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_precise_div.py b/third_party/ascend/examples/pytest_ut/test_precise_div.py new file mode 100644 index 000000000..a643589cc --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_precise_div.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl + +import torch +import torch_npu +import test_common + + +def torch_divRn(x0, x1): + return x0 / x1 + +@triton.jit +def triton_divRn(in_ptr0, in_ptr1, out_ptr0, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tl.div_rn(tmp0, tmp1) + tl.store(out_ptr0 + (x0), tmp2, None) + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 32, 2048, 64], + ]) + +def test_divRn(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype) + x2 = x1.masked_fill(x1 == 0, 1) + x2 = x2.npu() + y_ref = torch_divRn(x0, x2) + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + triton_divRn[ncore, 1, 1](x0, x2, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_precise_sqrt.py b/third_party/ascend/examples/pytest_ut/test_precise_sqrt.py new file mode 100644 index 000000000..5b3f04d40 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_precise_sqrt.py @@ -0,0 +1,31 @@ +import torch +import torch_npu +import triton +import triton.language as tl + +torch.set_printoptions(precision=10) +@triton.jit +def sqrtrn_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr ): + id = tl.program_id(axis=0) + start = id * BLOCK_SIZE + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + + output = x + tl.sqrt_rn(y) + tl.store(output_ptr + offsets, output, mask=mask) + +def sqrtrn(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(y) + grid = lambda meta: (triton.cdiv(output.numel(), meta['BLOCK_SIZE']),) + sqrtrn_kernel[grid](x, y, output, output.numel(), BLOCK_SIZE=512) + return output + +def test_sqrtrn_fp32(): + size = 10240 + x = torch.abs(torch.randn(size, device='npu', dtype=torch.float32)) + y = torch.abs(torch.randn(size, device='npu', dtype=torch.float32)) + ref = x + torch.sqrt(y) + cal = sqrtrn(x, y) + torch.testing.assert_close(cal, ref, rtol=1e-06, atol=1e-06, equal_nan=True) diff --git a/third_party/ascend/examples/pytest_ut/test_ptr_add.py b/third_party/ascend/examples/pytest_ut/test_ptr_add.py new file mode 100755 index 000000000..0eb4ad8b3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_ptr_add.py @@ -0,0 +1,62 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +def 
ptr_add(device): + + @triton.jit + def wrap_stacked(a_ptr, c_ptr, + M: tl.constexpr, + N: tl.constexpr, + stride_am: tl.constexpr, + stride_an: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + offs_am = (2 + tl.arange(0, 4)) % M + offs_an = tl.arange(0, 4) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + + offs_an[None, :] * stride_an) + + offs_cm = tl.arange(0, 4) + offs_cn = tl.arange(0, 4) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[ + None, :] + for k in range(0, 2): + a = tl.load(a_ptrs) + tl.store(c_ptrs, a) + a_ptrs += BLOCK_SIZE_K * stride_an + c_ptrs += BLOCK_SIZE_K * stride_an + + M = 4 + N = 8 + AA = torch.arange(0, M * N, dtype=torch.float32).npu() + A = AA.reshape((M, N)) + out = torch.full((M, N), 88888, dtype=torch.float32).npu() + grid = lambda meta: (1, ) + + wrap_stacked[grid](A, + out, + M, + N, + A.stride(0), + A.stride(1), + out.stride(0), + out.stride(1), + BLOCK_SIZE_K=4) + + # Expected output copied from running triton on NPU + expected_out = torch.tensor( + [[16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], + [0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]], + ).npu() + + assert torch.equal(expected_out.int(), out.int()) + + +def test_ptr_add(): + device = 'npu' + ptr_add(device) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_rand.py b/third_party/ascend/examples/pytest_ut/test_rand.py new file mode 100644 index 000000000..f2afe0f64 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rand.py @@ -0,0 +1,63 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +import math + + +@triton.jit +def kernel_rand(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(block_size): + global_offset = block_offset + inner_idx + rand_vals = tl.rand(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 + +@triton.jit +def kernel_randn(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(block_size): + global_offset = block_offset + inner_idx + rand_vals = tl.randn(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 + +@triton.jit +def kernel_randint(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(block_size): + global_offset = block_offset + inner_idx + rand_vals = tl.randint(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + tl.store(x_ptr + global_offset, rand_vals) # 存储随机数 + +@triton.jit +def kernel_randint4x(x_ptr, n_rounds: tl.constexpr, N: tl.constexpr, XBLOCK: tl.constexpr): + block_offset = tl.program_id(0) * XBLOCK + indices = tl.arange(0, 4) + block_size = XBLOCK if block_offset + XBLOCK <= N else N - block_offset + for inner_idx in range(0, block_size, step=4): + global_offset = block_offset + inner_idx + rand_vals = tl.randint4x(5, 10 + global_offset, n_rounds) # 对每个索引生成一个随机数 + mask = (global_offset + indices) < N + tl.store(x_ptr + global_offset + indices, rand_vals, mask) # 存储随机数 + +shapes = [(1,3)] + 
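tl.rand, tl.randn, tl.randint and tl.randint4x above are counter-based generators: each value is a pure function of the (seed, offset) pair, so relaunching a kernel with the same arguments reproduces the same stream. A determinism sketch, assuming an NPU with torch_npu and the kernel_rand defined above:

import torch
import torch_npu  # noqa: F401

a = torch.zeros(4, dtype=torch.float32).npu()
b = torch.zeros(4, dtype=torch.float32).npu()
kernel_rand[1, 1, 1](a, 10, a.numel(), a.numel())
kernel_rand[1, 1, 1](b, 10, b.numel(), b.numel())
assert torch.equal(a, b)  # identical (seed, offset) pairs give identical values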
+@pytest.mark.parametrize('shape', shapes) +def test_case(shape): + y_calf = torch.zeros(shape, dtype=eval('torch.float32')).npu() + + numel = y_calf.numel() + ncore = 1 if numel < 32 else 32 + xblock = math.ceil(numel / ncore) + + kernel_rand[ncore, 1, 1](y_calf, 10, numel, xblock) + kernel_randn[ncore, 1, 1](y_calf, 10, numel, xblock) + + y_cali = torch.zeros(shape, dtype=eval('torch.int32')).npu() + + kernel_randint[ncore, 1, 1](y_cali, 10, numel, xblock) + kernel_randint4x[ncore, 1, 1](y_cali, 10, numel, xblock) diff --git a/third_party/ascend/examples/pytest_ut/test_range.py b/third_party/ascend/examples/pytest_ut/test_range.py new file mode 100644 index 000000000..6a8be5f2e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_range.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + + +@triton.jit +def triton_add(in_ptr0, in_ptr1, out_ptr0, L : tl.constexpr, M : tl.constexpr, N : tl.constexpr): + lblk_idx = tl.arange(0,L) + mblk_idx = tl.arange(0,M) + nblk_idx = tl.arange(0,N) + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0=tl.load(in_ptr0+idx) + x1=tl.load(in_ptr1+idx) + ret = x0 + x1 + for i in tl.range(2,5,2): + ret = ret + x1 + for i in tl.static_range(2,10,3): + ret = ret + x0 + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0+odx, ret) + +testlist = [ + (3,5,8), +] + +def get_torch_typename(dtype): + if dtype == 'float32': + tyname = torch.float32 + elif dtype == 'int32': + tyname = torch.int32 + elif dtype == 'int64': + tyname = torch.int64 + elif dtype == 'float16': + tyname = torch.float16 + elif dtype == 'bfloat16': + tyname = torch.bfloat16 + elif dtype == 'int16': + tyname = torch.int16 + elif dtype == 'int8': + tyname = torch.int8 + elif dtype == 'bool': + tyname = torch.bool + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + return tyname + +typelist = ['int8','int16','int32','int64'] + +@pytest.mark.parametrize('L, M, N',testlist) +@pytest.mark.parametrize('sigtype',typelist) +def test_add(sigtype, L, M, N): + dtype = get_torch_typename(sigtype) + shape = (L, M, N) + x0 = test_common.generate_tensor(shape = (L, M, N),dtype = sigtype).npu() + x1 = test_common.generate_tensor(shape = (L, M, N),dtype = sigtype).npu() + y_ref = x0 + x1 + x1 + x1 + x0 + x0 + x0 + output = torch.zeros(shape, dtype=dtype).npu() + triton_add[1, 1, 1](x0, x1, output, L, M, N) + test_common.validate_cmp(sigtype, output, y_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_ravel.py b/third_party/ascend/examples/pytest_ut/test_ravel.py new file mode 100644 index 000000000..881d35c43 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_ravel.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
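In triton_add above, tl.range and tl.static_range follow Python range semantics, which fixes the iteration counts and hence the reference y_ref used in test_add:

# range(2, 5, 2) yields 2, 4      -> ret accumulates x1 twice
# range(2, 10, 3) yields 2, 5, 8  -> ret accumulates x0 three times
assert list(range(2, 5, 2)) == [2, 4]
assert list(range(2, 10, 3)) == [2, 5, 8]
# starting from ret = x0 + x1, the result is 4*x0 + 3*x1, matching y_ref in test_add above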
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def fn_npu_(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + X = tl.load(x_ptr+idx) + + ret = tl.ravel(X) + + oidx=tl.arange(0,XB*YB*ZB) + tl.store(output_ptr+oidx,ret) + +testlist = [ + ('float32',torch.float32,2,256,16), + ('float32',torch.float32,8,8,4), + + ('float16',torch.float16,2,256,16), + ('float16',torch.float16,8,8,4), + + ('int8',torch.int8,2,256,16), + ('int8',torch.int8,8,8,4), +] + +@pytest.mark.parametrize('sigtype, dtype, XB, YB, ZB',testlist) +def test_ravel(sigtype, dtype, XB, YB, ZB): + + x = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=dtype).npu() + ans = torch.ravel(x) + + print(ans[0:16]) + + output = torch.randint(1, (XB*YB*ZB,), dtype=dtype).npu() + + fn_npu_[1,1,1](output,x, XB, YB, ZB) + + print(output[0:16]) + + test_common.validate_cmp(sigtype,output,ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_reduce_count_vector.py b/third_party/ascend/examples/pytest_ut/test_reduce_count_vector.py new file mode 100644 index 000000000..1a262c0e7 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_reduce_count_vector.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl +import time +import test_common + +import torch +import torch_npu + + +def standard_count(x0, cmp_val, dim): + res = (x0 == cmp_val).sum(dim=dim) + return res + + +def standard_gt(x0, cmp_val, dim): + res = (x0 > cmp_val).sum(dim=dim) + return res + + +def standard_lt(x0, cmp_val, dim): + res = (x0 < cmp_val).sum(dim=dim) + return res + + +@triton.jit +def triton_count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, N) + x = tl.load(in_ptr0 + idx_block) + + tmp3 = (x == cmp_val) + # tmp3 bool -> tl.float32 + tmp4 = tmp3.to(tl.float32) + res = tl.sum(tmp4, dim) + + tl.store(out_ptr0 + idx_block, res) + + +@triton.jit +def triton_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, N) + x = tl.load(in_ptr0 + idx_block) + + tmp3 = (x > cmp_val) + # tmp3 bool -> tl.float32 + tmp4 = tmp3.to(tl.float32) + res = tl.sum(tmp4, dim) + + tl.store(out_ptr0 + idx_block, res) + + +@triton.jit +def triton_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, N) + x = tl.load(in_ptr0 + idx_block) + + tmp3 = (x < cmp_val) + # tmp3 bool -> tl.float32 + tmp4 = tmp3.to(tl.float32) + res = tl.sum(tmp4, dim) + + tl.store(out_ptr0 + idx_block, res) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + (torch.bfloat16, 'bfloat16'), + (torch.int8, 'int8'), + (torch.int16, 'int16'), + (torch.int32, 'int32'), + (torch.int64, 'int64'), +] + +# if shape axis = 32/256 , then actual shape = axis/element_size() +shapes = [ + (32, 32), +] + +map_for_64_t = {37: 31} + +CPM_VAL_INT = 8 +CPM_VAL_FLOAT = 0.5 + +# TO BE FIXED with mask +ops = [ + ('counti', triton_count, standard_count, CPM_VAL_INT), + ('countf', triton_gt, standard_gt, CPM_VAL_FLOAT), + ('countf', triton_lt, standard_lt, CPM_VAL_FLOAT), +] + + +def judge_continue(opName, sigtype): 
+ if opName == 'counti' and 'int' in sigtype: + return False + if opName == 'countf' and 'float' in sigtype: + return False + return True + + +@pytest.mark.parametrize('opName, tritonOp, standOp, cmp_val', ops) +@pytest.mark.parametrize('dtype, sigtype', types) +@pytest.mark.parametrize('N, NUMEL', shapes) +def test_reduce_count_vector(opName, tritonOp, standOp, cmp_val, dtype, sigtype, N, NUMEL): + if judge_continue(opName, sigtype): + return + torch.manual_seed(0) + torch_npu.npu.utils.set_device(0) + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == 'int64': + N = map_for_64_t[N] if N in map_for_64_t else N + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + ans = standOp(x0, cmp_val, 0) + x0 = x0.npu() + + output = torch.tensor(0, dtype=torch.float32).npu() + tritonOp[1, 1, 1](x0, output, cmp_val, dim=0, N=N, NUMEL=NUMEL, debug=True) + output = output.cpu().to(torch.int32) + # print(f'x0:{x0}\ntriton:{output}\ntorch:{ans}') + assert torch.equal(output, ans) + + +if __name__ == "__main__": + dtype = torch.float32 + sigtype = 'float32' + allshape = [(3, 32)] + for shape in allshape: + test_reduce_count_vector('countf', triton_lt, standard_lt, CPM_VAL_FLOAT, dtype, sigtype, shape[0], shape[1]) diff --git a/third_party/ascend/examples/pytest_ut/test_reduce_mean.py b/third_party/ascend/examples/pytest_ut/test_reduce_mean.py new file mode 100644 index 000000000..7e7201fc1 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_reduce_mean.py @@ -0,0 +1,64 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common +import numpy as np + +def numpy_mean_pr(x0, x1): + res = np.mean(x0, axis=-1) + x1 + return res + +@triton.jit +def triton_mean_pr(out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + rindex = tl.arange(0, RBLOCK)[None, :] + rmask = rindex < rnumel + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB) + xmask = xindex[:,None] < xnumel + x0 = xindex + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK*x0[:, None])), xmask & rmask) + tmp4 = tl.load(in_ptr1 + (x0), xindex < xnumel) + tmp1 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) + tmp3 = tl.sum(tmp1, 1) / RBLOCK + tmp5 = tmp3 + tmp4 + tl.store(out_ptr0 + (xindex), tmp5, None) + +@pytest.mark.parametrize('param_list', + [ + ['float32', (8, 8, 4), 8, 2], + ['float32', (8, 8, 64), 8, 2], + ['float32', (8, 8, 1024), 8, 2], + ['float16', (8, 8, 4), 8, 2], + ['float16', (8, 8, 64), 8, 2], + ['float16', (8, 8, 1024), 8, 2], + ['int8', (8, 8, 4), 8, 2], + ['int8', (8, 8, 64), 8, 2], + ['int8', (8, 8, 1024), 8, 2], + ] + ) + +def test_mean_pr(param_list): + dtype, shape, ncore, xblock_sub = param_list + import math + numel = math.prod(shape) + xblock = numel // shape[-1] // ncore + rblock = shape[-1] + assert (ncore * xblock * shape[-1] == numel) + xn1 = np.random.randn(shape[0], shape[1], shape[2]).astype(eval('np.' + dtype)) + xn2 = np.random.randn(shape[0], shape[1]).astype(eval('np.' 
+ dtype)) + x0 = torch.tensor(xn1).npu() + x1 = torch.tensor(xn2).npu() + y_ref = numpy_mean_pr(xn1, xn2) + if dtype == 'int8': + y_cal = test_common.generate_tensor(shape[:-1], 'float32').npu() + else: + y_cal = test_common.generate_tensor(shape[:-1], dtype).npu() + triton_mean_pr[ncore, 1, 1](y_cal, x0, x1, x1.numel(), rblock, xblock, xblock_sub, rblock) + if dtype == 'int8': + torch.allclose(torch.tensor(y_ref.astype(np.float32)).npu(), y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) + else: + torch.allclose(torch.tensor(y_ref).npu(), y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_reduce_sum.py b/third_party/ascend/examples/pytest_ut/test_reduce_sum.py new file mode 100644 index 000000000..6cda6c155 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_reduce_sum.py @@ -0,0 +1,69 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import numpy as np + +import os +import sys +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__),'..')) +sys.path.append(parent_dir) +import test_common + +# PR: Pointiwise-Reduction pattern, reduction in last axis +def numpy_sum_pr(x0, x1): + res = np.sum(x0, axis=-1) + x1 + return res + +@triton.jit +def triton_sum_pr(out_ptr0, in_ptr0, in_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + rindex = tl.arange(0, RBLOCK)[None, :] + rmask = rindex < rnumel + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB) + xmask = xindex[:,None] < xnumel + x0 = xindex + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (RBLOCK*x0[:, None])), xmask & rmask) + tmp4 = tl.load(in_ptr1 + (x0), xindex < xnumel) + tmp1 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) + tmp3 = tl.sum(tmp1, 1) + tmp5 = tmp3 + tmp4 + tl.store(out_ptr0 + (xindex), tmp5, None) + + +# fp16 use numpy +@pytest.mark.parametrize('param_list', + [ + # ['float32', (8, 8, 4), 8, 2], + ['float32', (8, 8, 64), 8, 2], + ['float32', (8, 8, 512), 8, 2], + ['float16', (8, 8, 4), 8, 2], + ['float16', (8, 8, 64), 8, 2], + ['float16', (8, 8, 512), 8, 2], + ['int8', (8, 8, 4), 8, 2], + ['int8', (8, 8, 64), 8, 2], + ['int8', (8, 8, 512), 8, 2], + ] + ) + +def test_sum_pr(param_list): + dtype, shape, ncore, xblock_sub = param_list + import math + numel = math.prod(shape) + xblock = numel // shape[-1] // ncore + rblock = shape[-1] + assert(ncore * xblock * shape[-1] == numel) + xn1 = np.random.randn(shape[0], shape[1], shape[2]).astype(eval('np.' + dtype)) + xn2 = np.random.randn(shape[0], shape[1]).astype(eval('np.' 
+ dtype)) + x0 = torch.tensor(xn1).npu() + x1 = torch.tensor(xn2).npu() + if dtype == 'int8': + y_ref = numpy_sum_pr(xn1, xn2).astype(np.int8) + else: + y_ref = numpy_sum_pr(xn1, xn2) + y_cal = test_common.generate_tensor(shape[:-1], dtype).npu() + triton_sum_pr[ncore,1,1](y_cal, x0, x1, x1.numel(), rblock, xblock, xblock_sub, rblock) + test_common.validate_cmp(dtype, y_cal, torch.tensor(y_ref).npu()) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_relu.py b/third_party/ascend/examples/pytest_ut/test_relu.py new file mode 100644 index 000000000..bdc8261f3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_relu.py @@ -0,0 +1,42 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common +import triton.language.extra.ascend.libdevice as libdevice + + +def torch_relu(x0, x1): + res = x0 + torch.relu(x1) + return res + + +@triton.jit +def triton_relu(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + x_index = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = x_index < xnumel + tmp0 = tl.load(in_ptr0 + x_index, xmask) + tmp1 = tl.load(in_ptr1 + x_index, xmask) + tmp2 = tmp0 + libdevice.relu(tmp1) + tl.store(out_ptr0 + x_index, tmp2, xmask) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_relu(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + torch_res = torch_relu(x0, x1) + # triton结果 + triton_res = test_common.generate_tensor(shape, dtype).npu() + triton_relu[ncore, 1, 1](x0, x1, triton_res, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_reshape.py b/third_party/ascend/examples/pytest_ut/test_reshape.py new file mode 100644 index 000000000..0ef74f069 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_reshape.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def fn_npu_(output_ptr, x_ptr,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + X = tl.load(x_ptr+idx) + + ret = tl.reshape(X,(ZB,XB*YB)) + + oidx=tl.arange(0,ZB)[:,None]*XB*YB+tl.arange(0,XB*YB)[None,:] + + tl.store(output_ptr+oidx,ret) + +testlist = [ + ('float32',torch.float32,2,256,16), + ('float32',torch.float32,8,8,4), + + ('float16',torch.float16,2,256,16), + ('float16',torch.float16,8,8,4), + + ('int8',torch.int8,2,256,16), + ('int8',torch.int8,8,8,4), +] + +@pytest.mark.parametrize('sigtype, dtype, XB, YB, ZB',testlist) +def test_ravel(sigtype, dtype, XB, YB, ZB): + + x = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=dtype).npu() + ans = torch.reshape(x,(ZB,XB*YB)) + + print(ans[0,0:16]) + + output = torch.randint(1, (ZB,XB*YB), dtype=dtype).npu() + + fn_npu_[1,1,1](output, x, XB, YB, ZB) + print(output[0,0:16]) + + test_common.validate_cmp(sigtype,output,ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_rint.py b/third_party/ascend/examples/pytest_ut/test_rint.py new file mode 100644 index 000000000..74bab1439 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rint.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + + +def torch_rint(x0): + res = torch.round(x0) + return res + + +@triton.jit +def triton_rint(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.math.rint(tmp0) + tl.store(out_ptr0 + (x0), tmp1, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 32, 2048, 64], + ]) +def test_rint(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype) + y_ref = torch_rint(x0) + tyname = test_common.get_triton_sig_typename(dtype) + + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + x0 = x0.npu() + triton_rint[ncore, 1, 1](x0, y_cal, xblock, xblock_sub, debug=True) + y_ref = y_ref.npu() + test_common.validate_cmp_with_expection(dtype, y_cal, y_ref, True) diff --git a/third_party/ascend/examples/pytest_ut/test_rms_norm.py b/third_party/ascend/examples/pytest_ut/test_rms_norm.py new file mode 100644 index 000000000..495ce30ac --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rms_norm.py @@ -0,0 +1,127 @@ +# Copyright (c) 2023 by Microsoft Corporation. +# Licensed under the MIT license. + +# Triton RMSNorm kernels from LightLLM project. 
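+# RMSNorm: y = w * x / sqrt(mean(x^2) + eps). The kernel below makes two
+# passes per row: one to accumulate sum(x * x), one to scale by rstd and w.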
+ +# Reference: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py + +import torch +import triton +import triton.language as tl +import torch_npu + +@triton.jit +def _rms_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + y = x_hat * w + # Write output + tl.store(Y + cols, y.to(tl.float16), mask=mask) + +# have to change the block_size +@torch.inference_mode() +def rms_norm(x, weight, eps, out=None): + # allocate output, tl.store save y in tl.float16 + y = torch.empty_like(x, dtype=torch.float16) if out is None else out + # reshape input data into 2D tensor + x_arg = x.view(-1, x.shape[-1]) + M, N = x_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 + num_warps = 8 + # enqueue kernel + kernel = _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + return y, kernel + + +def _rms_norm(shape, datatype): + x = torch.randn(shape[0], shape[1], dtype=datatype, device="npu") + weight = torch.randn(shape[1], dtype=datatype, device="npu") + y, kernel = rms_norm(x, weight, eps=1e-5) + eps1 = 1e-5 + if datatype == torch.bfloat16 or datatype == torch.float16: + x = x.to(torch.float32) + rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps1) # 计算均方根 + x_norm = x / rms # 标准化 + y_ref = weight * x_norm + y_ref = y_ref.to(torch.float16) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +def test_cases(): + _rms_norm((16, 1024), torch.float16) + _rms_norm((16, 1024), torch.float32) + _rms_norm((16, 1024), torch.bfloat16) + _rms_norm((128, 3), torch.bfloat16) + _rms_norm((128, 16), torch.bfloat16) + _rms_norm((128, 37), torch.bfloat16) + _rms_norm((128, 64), torch.bfloat16) + _rms_norm((128, 781), torch.bfloat16) + _rms_norm((128, 781), torch.bfloat16) + _rms_norm((16, 1024), torch.float16) + _rms_norm((16, 1024), torch.float32) + _rms_norm((16, 1024), torch.bfloat16) + + _rms_norm((128, 128), torch.float16) + _rms_norm((128, 128), torch.float32) + _rms_norm((128, 128), torch.bfloat16) + + _rms_norm((1, 128), torch.float16) + _rms_norm((1, 128), torch.float32) + _rms_norm((1, 128), torch.bfloat16) + + _rms_norm((65535, 128), torch.float16) + _rms_norm((65535, 128), torch.float32) + _rms_norm((65535, 128), torch.bfloat16) + + _rms_norm((128, 3), 
torch.float16) + _rms_norm((128, 3), torch.float32) + _rms_norm((128, 3), torch.bfloat16) + + _rms_norm((128, 16), torch.float16) + _rms_norm((128, 16), torch.float32) + _rms_norm((128, 16), torch.bfloat16) + + _rms_norm((128, 37), torch.float16) + _rms_norm((128, 37), torch.float32) + _rms_norm((128, 37), torch.bfloat16) + + _rms_norm((128, 64), torch.float16) + _rms_norm((128, 64), torch.float32) + _rms_norm((128, 64), torch.bfloat16) + + _rms_norm((128, 781), torch.float16) + _rms_norm((128, 781), torch.float32) + _rms_norm((128, 781), torch.bfloat16) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_rotary_embedding.py b/third_party/ascend/examples/pytest_ut/test_rotary_embedding.py new file mode 100644 index 000000000..e8c726dde --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rotary_embedding.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023 by Microsoft Corporation. +# Licensed under the MIT license. + +"""Rotary embedding kernel implemented by Triton. + +GPT-NeoX style +""" + +import torch +import torch_npu +import triton +import triton.language as tl + + +@triton.jit +def rotary_embedding_kernel( + state, # [num_tokens, head_num, head_dim] + cos, # [num_tokens, 1, head_dim // 2] + sin, # [num_tokens, 1, head_dim // 2] + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + # stride_sin_n, + # stride_sin_d, + num_tokens, + num_heads, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_index = tl.program_id(0) + token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N) + head_index = tl.program_id(1) + head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H) + + dim_range_x = tl.arange(0, BLOCK_D // 2) + dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D) + + state_x_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_x[None, None, :] * stride_state_d + ) + state_y_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_y[None, None, :] * stride_state_d + ) + + cos_sim_offset = ( + token_range[:, None, None] * stride_cos_n + + dim_range_x[None, None, :] * stride_cos_d + ) + + state_x = tl.load( + state + state_x_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + state_y = tl.load( + state + state_y_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + + cos_loaded = tl.load( + cos + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ) + sin_loaded = tl.load( + sin + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ) + + out_x = state_x * cos_loaded - state_y * sin_loaded + out_y = state_x * sin_loaded + state_y * cos_loaded + + tl.store( + state + state_x_offset, + out_x, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + tl.store( + state + state_y_offset, + out_y, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + + +@torch.inference_mode() +def rotary_embedding(state, cos, sin): + num_tokens = state.shape[0] + num_heads = state.shape[1] + head_dim = state.shape[2] + + #BLOCK_N = 32 + BLOCK_N = 16 + BLOCK_H = 4 + grid = ( + triton.cdiv(num_tokens, BLOCK_N), + triton.cdiv(num_heads, BLOCK_H), + ) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + 
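+    # Each program instance rotates a BLOCK_N x BLOCK_H tile of (token, head)
+    # pairs across the full head_dim; out-of-range tokens and heads are masked
+    # inside the kernel, so the ceiling-divided grid above is safe.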
kernel = rotary_embedding_kernel[grid]( + state, + cos, + sin, + state.stride(0), + state.stride(1), + state.stride(2), + cos.stride(0), + cos.stride(2), + # sin.stride(0), + # sin.stride(2), + num_tokens, + num_heads, + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + BLOCK_D=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +def torch_rotary_embedding(state, cos, sin): + _, _, dim = state.shape + state_x = state[:, :, 0 : dim // 2] + state_y = state[:, :, dim // 2 : dim] + out_x = state_x * cos - state_y * sin + out_y = state_x * sin + state_y * cos + return torch.cat((out_x, out_y), dim=-1) + + +def rotary_emb(tokens, heads, headdim, dtype): + tokens_num = tokens + num_heads = heads + head_dim = headdim + max_positions = 1024 + + # torch.float16 has floating point problem in Triton 2.0.0 + # But it works fine in Triton 2.1.0 + state = torch.randn((tokens_num, num_heads, head_dim), dtype=dtype, device="npu") + cos_shape = (tokens_num, 1, head_dim // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="npu") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="npu") + # forward pass + torch_result = torch_rotary_embedding(state, cos, sin) + rotary_embedding(state, cos, sin) + triton_result = state # state is modified in-place + torch.testing.assert_close(torch_result, triton_result, rtol=1e-3, atol=1e-3) + +def test_cases(): + rotary_emb(256, 96, 128, torch.float16) + rotary_emb(256, 96, 128, torch.float32) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_rotatry_gpt.py b/third_party/ascend/examples/pytest_ut/test_rotatry_gpt.py new file mode 100644 index 000000000..d31b61485 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rotatry_gpt.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023 by Microsoft Corporation. +# Licensed under the MIT license. + +""" +Rotary embedding kernel implemented by Triton. 
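+(GPT-J style rotates interleaved channel pairs (2*i, 2*i + 1) within each head,
+whereas the GPT-NeoX style pairs channel i with channel i + head_dim // 2.)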
+GPT-J style +""" + +import torch +import torch_npu +import triton +import triton.language as tl + + +@triton.jit +def rotary_embedding_kernel( + state, # [num_tokens, head_num, head_dim] + cos, # [num_tokens, 1, head_dim // 2] + sin, # [num_tokens, 1, head_dim // 2] + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + # stride_sin_n, + # stride_sin_d, + num_tokens, + num_heads, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_index = tl.program_id(0) + token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N) + head_index = tl.program_id(1) + head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H) + + dim_range = tl.arange(0, BLOCK_D // 2) + dim_range_x = dim_range * 2 + dim_range_y = dim_range * 2 + 1 + + # tl.device_print("dim x", dim_range_x) + # tl.device_print("dim y", dim_range_y) + + state_x_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_x[None, None, :] * stride_state_d + ) + + state_y_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_y[None, None, :] * stride_state_d + ) + + cos_sim_offset = ( + token_range[:, None, None] * stride_cos_n + + dim_range[None, None, :] * stride_cos_d + ) + + state_x = tl.load( + state + state_x_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + state_y = tl.load( + state + state_y_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + + cos_loaded = tl.load( + cos + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ) + sin_loaded = tl.load( + sin + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ) + + out_x = state_x * cos_loaded - state_y * sin_loaded + out_y = state_x * sin_loaded + state_y * cos_loaded + + tl.store( + state + state_x_offset, + out_x, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + tl.store( + state + state_y_offset, + out_y, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + + +@torch.inference_mode() +def rotary_embedding(state, cos, sin): + num_tokens = state.shape[0] + num_heads = state.shape[1] + head_dim = state.shape[2] + + BLOCK_N = 8 + BLOCK_H = 4 + grid = ( + triton.cdiv(num_tokens, BLOCK_N), + triton.cdiv(num_heads, BLOCK_H), + ) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + kernel = rotary_embedding_kernel[grid]( + state, + cos, + sin, + state.stride(0), + state.stride(1), + state.stride(2), + cos.stride(0), + cos.stride(2), + # sin.stride(0), + # sin.stride(2), + num_tokens, + num_heads, + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + BLOCK_D=head_dim, + num_warps=num_warps, + num_stages=1, + ) + # print(kernel.asm['ttir']) + return + + +def torch_rotary_embedding(state, cos, sin): + _, _, dim = state.shape + state_x = state[:, :, 0 : dim : 2] + state_y = state[:, :, 1 : dim : 2] + out_x = state_x * cos - state_y * sin + out_y = state_x * sin + state_y * cos + out = torch.empty_like(state).npu() + out[:, :, 0 : dim : 2] = out_x + out[:, :, 1 : dim : 2] = out_y + return out + + +def _rotary_emb(dtype): + tokens_num = 256 + num_heads = 96 + head_dim = 128 + max_positions = 1024 + + # torch.float16 has floating point problem in Triton 2.0.0 + # But it works fine in Triton 2.1.0 + state = 
torch.randn((tokens_num, num_heads, head_dim), dtype=dtype, device="npu") + cos_shape = (tokens_num, 1, head_dim // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="npu") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="npu") + # forward pass + torch_result = torch_rotary_embedding(state, cos, sin) + rotary_embedding(state, cos, sin) + triton_result = state # state is modified in-place + # print(torch_result[1][0]) + # print(triton_result[1][0]) + # Note: This test is not accurate enough. + assert torch.allclose(torch_result, triton_result, atol=1e-2, rtol=1e-7) + +def test_rotary_emb(): + _rotary_emb(torch.float16) + _rotary_emb(torch.float32) diff --git a/third_party/ascend/examples/pytest_ut/test_rotaty_embedding_gpt.py b/third_party/ascend/examples/pytest_ut/test_rotaty_embedding_gpt.py new file mode 100644 index 000000000..2c7f96b3a --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rotaty_embedding_gpt.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 by Microsoft Corporation. +# Licensed under the MIT license. + +"""Rotary embedding kernel implemented by Triton. + +GPT-NeoX style +""" + +import torch +import torch_npu +import triton +import triton.language as tl + + +@triton.jit +def rotary_embedding_kernel( + state, # [num_tokens, head_num, head_dim] + cos, # [num_tokens, 1, head_dim // 2] + sin, # [num_tokens, 1, head_dim // 2] + stride_state_n, + stride_state_h, + stride_state_d, + stride_cos_n, + stride_cos_d, + # stride_sin_n, + # stride_sin_d, + num_tokens, + num_heads, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_index = tl.program_id(0) + token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N) + head_index = tl.program_id(1) + head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H) + + dim_range_x = tl.arange(0, BLOCK_D // 2) + dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D) + + state_x_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_x[None, None, :] * stride_state_d + ) + state_y_offset = ( + token_range[:, None, None] * stride_state_n + + head_range[None, :, None] * stride_state_h + + dim_range_y[None, None, :] * stride_state_d + ) + + cos_sim_offset = ( + token_range[:, None, None] * stride_cos_n + + dim_range_x[None, None, :] * stride_cos_d + ) + + state_x = tl.load( + state + state_x_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + state_y = tl.load( + state + state_y_offset, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + other=0.0, + ) + + cos_loaded = tl.load( + cos + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ) + sin_loaded = tl.load( + sin + cos_sim_offset, + mask=token_range[:, None, None] < num_tokens, + other=0.0, + ) + + out_x = state_x * cos_loaded - state_y * sin_loaded + out_y = state_x * sin_loaded + state_y * cos_loaded + + tl.store( + state + state_x_offset, + out_x, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + tl.store( + state + state_y_offset, + out_y, + mask=(token_range[:, None, None] < num_tokens) + & (head_range[None, :, None] < num_heads), + ) + + +@torch.inference_mode() +def rotary_embedding(state, cos, sin): + num_tokens = state.shape[0] + num_heads = state.shape[1] + head_dim = state.shape[2] + + BLOCK_N = 16 + BLOCK_H = 4 + grid = ( + triton.cdiv(num_tokens, BLOCK_N), + 
triton.cdiv(num_heads, BLOCK_H), + ) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + kernel = rotary_embedding_kernel[grid]( + state, + cos, + sin, + state.stride(0), + state.stride(1), + state.stride(2), + cos.stride(0), + cos.stride(2), + # sin.stride(0), + # sin.stride(2), + num_tokens, + num_heads, + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + BLOCK_D=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +def torch_rotary_embedding(state, cos, sin): + _, _, dim = state.shape + state_x = state[:, :, 0 : dim // 2] + state_y = state[:, :, dim // 2 : dim] + out_x = state_x * cos - state_y * sin + out_y = state_x * sin + state_y * cos + return torch.cat((out_x, out_y), dim=-1) + + +def rotary_emb(tokens, heads, headdim, dtype): + tokens_num = tokens + num_heads = heads + head_dim = headdim + #max_positions = 1024 + + # torch.float16 has floating point problem in Triton 2.0.0 + # But it works fine in Triton 2.1.0 + state = torch.randn((tokens_num, num_heads, head_dim), dtype=dtype, device="npu") + cos_shape = (tokens_num, 1, head_dim // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="npu") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="npu") + # forward pass + torch_result = torch_rotary_embedding(state, cos, sin) + rotary_embedding(state, cos, sin) + triton_result = state # state is modified in-place + torch.testing.assert_close(torch_result, triton_result, rtol=1e-3, atol=1e-3) + +def test_cases(): + rotary_emb(256, 96, 128, torch.float16) + rotary_emb(256, 96, 128, torch.float32) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_rshift.py b/third_party/ascend/examples/pytest_ut/test_rshift.py new file mode 100644 index 000000000..e7f0890ef --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rshift.py @@ -0,0 +1,79 @@ +import pytest +import triton +import triton.language as tl +import time +import test_common +import torch +import torch_npu + + +def standard_unary(x0, dtype): + res = x0 >> 2 + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + tmp = tl.cast(2, tl.int8) + ret = x >> tmp + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + # (torch.float32, 'float32'), + # (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + 
# print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + # print(out) + + test_common.validate_cmp(sigtype, out, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_rsqrt.py b/third_party/ascend/examples/pytest_ut/test_rsqrt.py new file mode 100644 index 000000000..393de803c --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_rsqrt.py @@ -0,0 +1,42 @@ +import triton +import triton.language as tl +import torch +import numpy as np +import pytest +import test_common + +def numpy_rsqrt(x0, x1): + res = x0 + 1.0 / (np.sqrt(x1)) + return res + +@triton.jit +def triton_rsqrt(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.load(in_ptr1 + (x0), xmask) + tmp2 = tmp0 + tl.rsqrt(tmp1) + tl.store(out_ptr0 + (xindex), tmp2, xmask) + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) + +def test_rsqrt(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = np.abs(np.random.randn(shape[0], shape[1], shape[2])).astype(eval('np.' + dtype)) + x1 = np.abs(np.random.randn(shape[0], shape[1], shape[2])).astype(eval('np.' + dtype)) + x0_npu = torch.tensor(x0).npu() + x1_npu = torch.tensor(x1).npu() + # numpy结果 + numpy_res = numpy_rsqrt(x0, x1) + # triton结果 + triton_res = test_common.generate_tensor(shape, dtype).npu() + triton_rsqrt[ncore, 1, 1](x0_npu, x1_npu, triton_res, x0_npu.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, triton_res, torch.tensor(numpy_res).npu()) diff --git a/third_party/ascend/examples/pytest_ut/test_scalar_calc.py b/third_party/ascend/examples/pytest_ut/test_scalar_calc.py new file mode 100644 index 000000000..30ba63b4d --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_scalar_calc.py @@ -0,0 +1,753 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common + +### add +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_add_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0 + 2.0 + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y + 2.0 + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### sub +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_sub_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0 - 2.0 + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y - 2.0 + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### mul +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def 
test_scalar_mul_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0 * 2.0 + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y * 2.0 + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### div +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_div_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0 / 2.0 + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y / 2.0 + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### remf +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_remf_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0 % 2.0 + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y % 2.0 + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### negf +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_negf_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = -tmp0 + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = -y + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### cmpf +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_cmpf_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = (tmp0 > 0.5).to(tmp0.dtype) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = (y > 0.5).to(y.dtype) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### ceil +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_ceil_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.math.ceil(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.ceil(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), 
dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### floor +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_floor_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.math.floor(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.floor(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### maximum(propagate_nan == tl.PropagateNan.ALL) +# setting propagate_nan=tl.PropagateNan.ALL to generate arith::MaximumFOp +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_maximum_nanall_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + tl.static_assert(N > 1) + tmp0 = tl.load(in_ptr0 + 0) + tmp1 = tl.load(in_ptr0 + 1) + tmp1 = tl.maximum(tmp0, tmp1, propagate_nan=tl.PropagateNan.ALL) + tl.store(out_ptr0 + 0, tmp1) + + def torch_func(x0): + y0 = x0[0] + y1 = x0[1] + y = torch.maximum(y0, y1) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### maximum(propagate_nan == tl.PropagateNan.NONE) +# setting propagate_nan=tl.PropagateNan.NONE to generate arith::MaxNumFOp +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_maximum_nannone_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + tl.static_assert(N > 1) + tmp0 = tl.load(in_ptr0 + 0) + tmp1 = tl.load(in_ptr0 + 1) + tmp1 = tl.maximum(tmp0, tmp1, propagate_nan=tl.PropagateNan.ALL) + tl.store(out_ptr0 + 0, tmp1) + + def torch_func(x0): + y0 = x0[0] + y1 = x0[1] + y = torch.fmax(y0, y1) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### minimum(propagate_nan == tl.PropagateNan.ALL) +# setting propagate_nan=tl.PropagateNan.ALL to generate arith::MinimumFOp +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_minimum_nanall_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + tl.static_assert(N > 1) + tmp0 = tl.load(in_ptr0 + 0) + tmp1 = tl.load(in_ptr0 + 1) + tmp1 = tl.minimum(tmp0, tmp1, propagate_nan=tl.PropagateNan.ALL) + tl.store(out_ptr0 + 0, tmp1) + + def torch_func(x0): + y0 = x0[0] + y1 = x0[1] + y = torch.minimum(y0, y1) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### minimum(propagate_nan == tl.PropagateNan.NONE) +# setting propagate_nan=tl.PropagateNan.NONE to generate arith::MinNumFOp +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def 
test_scalar_minimum_nannone_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + tl.static_assert(N > 1) + tmp0 = tl.load(in_ptr0 + 0) + tmp1 = tl.load(in_ptr0 + 1) + tmp1 = tl.minimum(tmp0, tmp1, propagate_nan=tl.PropagateNan.NONE) + tl.store(out_ptr0 + 0, tmp1) + + def torch_func(x0): + y0 = x0[0] + y1 = x0[1] + y = torch.fmin(y0, y1) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### extf +@pytest.mark.parametrize('param_list', + [ + ['float16', 'float32', 16] + ] + ) +def test_scalar_extf_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0.to(tl.float32) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y.to(torch.float32) + return torch.tensor(y) + + src_dtype, dst_dtype, N = param_list + x0 = test_common.generate_tensor((N,), src_dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dst_dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dst_dtype, y_cal[0], y_ref) + +### truncf +@pytest.mark.parametrize('param_list', + [ + ['float32', 'float16', 16] + ] + ) +def test_scalar_truncf_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tmp0.to(tl.float16) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = y.to(torch.float16) + return torch.tensor(y) + + src_dtype, dst_dtype, N = param_list + x0 = test_common.generate_tensor((N,), src_dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dst_dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dst_dtype, y_cal[0], y_ref) + +### exp +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_exp_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.math.exp(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.exp(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### exp2 +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_exp_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.math.exp2(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.exp2(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### log +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_log_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp0 = tl.abs(tmp0) + tmp1 = 
tl.log(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.abs(y) + y = torch.log(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### log2 +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_log2_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp0 = tl.abs(tmp0) + tmp1 = tl.log2(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.abs(y) + y = torch.log2(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### sin +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_sin_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.sin(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.sin(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### cos +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_cos_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.cos(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.cos(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### abs +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_abs_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.abs(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.abs(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### erf +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_erf_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.erf(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.erf(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, 
y_cal[0], y_ref) + +### sqrt +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_sqrt_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp0 = tl.abs(tmp0) + tmp1 = tl.math.sqrt(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.abs(y) + y = torch.sqrt(y) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### rsqrt +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_rsqrt_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp0 = tl.abs(tmp0) + tmp1 = tl.math.rsqrt(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.abs(y) + y = torch.rsqrt(y) + return y.clone().detach() + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### tanh +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_tanh_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + idx = 0 + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.math.tanh(tmp0) + tl.store(out_ptr0 + idx, tmp1) + + def torch_func(x0): + y = x0[0] + y = torch.tanh(y) + return y.clone().detach() + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) + +### sum +@pytest.mark.parametrize('param_list', + [ + ['float32', 16] + ] + ) +def test_scalar_sum_calc(param_list): + + @triton.jit + def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): + tmp0 = tl.load(in_ptr0 + tl.arange(0, N)) + tmp1 = tl.sum(tmp0, 0) + tl.store(out_ptr0 + 0, tmp1) + + def torch_func(x0): + y = torch.sum(x0, 0) + return torch.tensor(y) + + dtype, N = param_list + x0 = test_common.generate_tensor((N,), dtype).npu() + y_ref = torch_func(x0) + y_cal = test_common.generate_tensor((1,), dtype).npu() + triton_kernel[1, 1, 1](y_cal, x0, N=N) + test_common.validate_cmp(dtype, y_cal[0], y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_sigmoid.py b/third_party/ascend/examples/pytest_ut/test_sigmoid.py new file mode 100644 index 000000000..3d95128d1 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sigmoid.py @@ -0,0 +1,39 @@ +import triton +import triton.language as tl +import torch +import pytest +import test_common + +def torch_sigmoid(x0, x1): + res = x0 + torch.sigmoid(x1) + return res + +@triton.jit +def triton_sigmoid(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.load(in_ptr1 + (x0), xmask) + tmp2 = tmp0 + tl.sigmoid(tmp1) + tl.store(out_ptr0 + (xindex), tmp2, xmask) + + 
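+# The kernel fuses x0 + sigmoid(x1) elementwise, walking the flat index space
+# in XBLOCK chunks that are further tiled into XBLOCK_SUB-sized sub-blocks.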
+@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ]) +def test_sigmoid(param_list): + # 生成数据 + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + # torch结果 + y_ref = torch_sigmoid(x0, x1) + # triton结果 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_sigmoid[ncore, 1, 1](x0, x1, y_cal, x0.numel(), xblock, xblock_sub) + # 比较结果 + test_common.validate_cmp(dtype, y_cal, y_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_silu.py b/third_party/ascend/examples/pytest_ut/test_silu.py new file mode 100644 index 000000000..283fb8f5b --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_silu.py @@ -0,0 +1,78 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_unary(x0, dtype): + res = x0 * (1/(1+torch.exp(-x0))) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = x * (1/(1+tl.math.exp(-x))) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_sin.py b/third_party/ascend/examples/pytest_ut/test_sin.py new file mode 100644 index 000000000..30e32ab34 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sin.py @@ -0,0 +1,78 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_unary(x0, dtype): + res = torch.sin(x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = tl.sin(x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def 
triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + # (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_softmax.py b/third_party/ascend/examples/pytest_ut/test_softmax.py new file mode 100644 index 000000000..cb3878dde --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_softmax.py @@ -0,0 +1,114 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import test_common +import pytest + +def naive_softmax(x): + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +kernels = {} + +def softmax(x, stream): + n_rows, n_cols = x.shape + + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + y = torch.empty_like(x) + + kernel, 
num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + num_programs = 32 + kernel = softmax_kernel + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)]( + y, + x, + x.stride(0), + y.stride(0), + n_rows, + n_cols, + BLOCK_SIZE + ) + return y + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + (torch.bfloat16, 'bfloat16'), +] + +shapes = [ + (1823, 781), + (1823, 2), + (1823, 4), + (1823, -32), + (1823, -100), + (1823, -256), +] + +map_for_64_t = {37: 31} + +@pytest.mark.skip(reason="randomly failed") +@pytest.mark.parametrize('dtype, sigtype',types) +@pytest.mark.parametrize('M, N',shapes) +def test_softmax(dtype, sigtype, M, N): + torch_npu.npu.utils.set_device(0) + M = (-M)//torch.tensor(0,dtype=dtype).element_size() if M<0 else M + N = (-N)//torch.tensor(0,dtype=dtype).element_size() if N<0 else N + + if sigtype == 'int64': + M = map_for_64_t[M] if M in map_for_64_t else M + N = map_for_64_t[N] if N in map_for_64_t else N + + device = torch.npu.current_device() + stream = torch.npu.current_stream(device).npu_stream + torch.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device='npu') + y_triton = softmax(x, stream) + y_torch = torch.softmax(x, axis=1) + test_common.validate_cmp(sigtype, y_triton, y_torch) diff --git a/third_party/ascend/examples/pytest_ut/test_sort.py b/third_party/ascend/examples/pytest_ut/test_sort.py new file mode 100644 index 000000000..c49d39a65 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sort.py @@ -0,0 +1,66 @@ +import triton +import pytest +import torch +import triton.language as tl +import test_common + +# --------------- +# test sort op +# --------------- + + +@triton.jit +def sort_kernel_2d(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending, dim=1) + tl.store(Z + off2d, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", [(1, 512), (8, 64), (256, 16), (512, 8)]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'float16', 'float32', 'bfloat16']) +def test_sort_2d(shape, descending, dtype): + + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending)[0] + + triton_res = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + N = x.shape[0] + M = x.shape[1] + sort_kernel_2d[(1, )](x, triton_res, N, M, descending) + assert (torch_res == triton_res).all(), (torch_res, triton_res) + + +@triton.jit +def sort_kernel_3d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, descending: tl.constexpr): + off2 = tl.arange(0, D2) + off1 = tl.arange(0, D1) * D2 + off0 = tl.arange(0, D0) * D1 * D2 + + off = off2[None, None, :] + off1[None, :, None] + off0[:, None, None] + x = tl.load(X + off) + + x = tl.sort(x, descending=descending, dim=2) + + tl.store(Z + off, x) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape", [(8, 4, 16)]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype", ['int8', 'int16', 'float16', 'float32', 'bfloat16']) +def test_sort_3d(shape, descending, dtype): + + x = test_common.generate_tensor(shape, dtype).npu() + torch_res = torch.sort(x, descending=descending)[0] + + triton_res = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() + D0 = x.shape[0] + D1 = x.shape[1] + D2 = x.shape[2] + sort_kernel_3d[(1, )](x, triton_res, D0, D1, D2, descending) + assert (torch_res == triton_res).all(), (torch_res, triton_res) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_split.py b/third_party/ascend/examples/pytest_ut/test_split.py new file mode 100644 index 000000000..67226a295 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_split.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import triton +import triton.language as tl + +import torch +import torch_npu +import pytest +import test_common + +@triton.jit +def fn_npu_(output_ptr, x_ptr,output_ptr1,XB : tl.constexpr,YB : tl.constexpr,ZB : tl.constexpr): + xidx=tl.arange(0,XB) + yidx=tl.arange(0,YB) + zidx=tl.arange(0,ZB) + + idx=xidx[:,None,None]*YB*ZB+yidx[None,:,None]*ZB+zidx[None,None,:] + + X = tl.load(x_ptr+idx) + + xx,yy = tl.split(X) + + oidx=xidx[:,None]*YB+yidx[None,:] + + + tl.store(output_ptr+oidx,xx) + tl.store(output_ptr1+oidx,yy) + + + +@pytest.mark.parametrize('para_type,data_type,XB,YB,ZB', + [ + ['float32',torch.float32,16,256,2], + ['float32',torch.float32,8,8,2], + ['float16',torch.float16,16,256,2], + ['float16',torch.float16,8,8,2], + ['int8',torch.int8,8,128,2], + ['int8',torch.int8,8,8,2], + ] + ) +def test_split(para_type,data_type,XB,YB,ZB): + + x = torch.randint(low=-128,high=128,size=(XB,YB,ZB),dtype=data_type).npu() + + a,b = torch.split(x,1,dim = -1) + a = a.reshape(XB,YB) + b = b.reshape(XB,YB) + print(a) + print(b) + + output = torch.randint(1, (XB,YB), dtype=data_type).npu() + output1 = torch.randint(1, (XB,YB), dtype=data_type).npu() + fn_npu_[1,1,1](output,x,output1, XB, YB, ZB, debug = True) + + print(output) + print(output1) + + test_common.validate_cmp(para_type, a, output) + test_common.validate_cmp(para_type, b, output1) + diff --git a/third_party/ascend/examples/pytest_ut/test_sqrt.py b/third_party/ascend/examples/pytest_ut/test_sqrt.py new file mode 100644 index 000000000..f0879cecf --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sqrt.py @@ -0,0 +1,78 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +def standard_unary(x0, dtype): + res = torch.sqrt(x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = tl.math.sqrt(x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + + +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N 
= (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_static_print_and_assert.py b/third_party/ascend/examples/pytest_ut/test_static_print_and_assert.py new file mode 100644 index 000000000..f49c134a1 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_static_print_and_assert.py @@ -0,0 +1,164 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common +import functools +import os +import re + +shape = (8,) +XS = 8 +XVALS_INT = [0, + -128, # torch.iinfo(torch.int8).min + 127, # torch.iinfo(torch.int8).max + -32768, # torch.iinfo(torch.int16).min + 32767, # torch.iinfo(torch.int16).max + -2147483648, # torch.iinfo(torch.int32).min + 2147483647, # torch.iinfo(torch.int32).max + 9223372036854775807] # torch.iinfo(torch.int64).max + +XVALS_FP = [0.0000000000e+00, # 0 + 1.1921000009e-07, # torch.finfo(torch.float32).eps + 9.7655999707e-04, # torch.finfo(torch.float16).eps + 7.8125000000e-03, # torch.finfo(torch.bfloat16).eps + 3.4027999388e+38, # torch.finfo(torch.float32).max + 6.5504000000e+04, # torch.finfo(torch.float16).max + 3.3894999515e+38, # torch.finfo(torch.bfloat16).max + 1.0000000000e+00] # 1 + + +def torch_func(x0, x1): + res = x0 + x1 + return res + + +@triton.jit +def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr, print_data_ptr: tl.constexpr, + assert_data_ptr: tl.constexpr): + idx = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.static_print(print_data_ptr) + tl.static_assert(assert_data_ptr == assert_data_ptr, "assert_data should equal assert_data") + tl.store(out_ptr0 + idx, tmp2) + + +def triton_func(x0, x1, XS, print_data_ptr, assert_data_ptr): + out = torch.empty_like(x0) + triton_kernel[1, 1, 1](out, x0, x1, XS, print_data_ptr, assert_data_ptr) + return out + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['int8']) +@test_common.capture_output("-128") +def test_static_print_int8(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, -128, XVALS_INT[0]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['int16']) +@test_common.capture_output("-32768") +def test_static_print_int16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, -32768, XVALS_INT[2]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['int32']) 
+@test_common.capture_output("-2147483648") +def test_static_print_int32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, -2147483648, XVALS_INT[4]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['int64']) +@test_common.capture_output("9223372036854775807") +def test_static_print_int64(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 9223372036854775807, XVALS_INT[-1]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['float16']) +@test_common.capture_output("1.1921000009e-07") +def test_static_print_float16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 1.1921000009e-07, XVALS_FP[1]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['float32']) +@test_common.capture_output("0.0078125") +def test_static_print_float32(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 7.8125000000e-03, XVALS_FP[0]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['bfloat16']) +@test_common.capture_output("0.00097655999707") +def test_static_print_bfloat16(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_FP[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, 9.7655999707e-04, XVALS_FP[2]) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) + + +@pytest.mark.skip(reason="waiting for TA to support") +@pytest.mark.parametrize('sigtype', ['int8']) +@test_common.capture_output("True") +def test_static_print_bool(capsys, sigtype): + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1, XS, True, True) + test_common.validate_cmp(sigtype, triton_cal, torch_ref) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_store_scalar.py b/third_party/ascend/examples/pytest_ut/test_store_scalar.py new file mode 100644 index 000000000..e8756833b --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_store_scalar.py @@ -0,0 +1,24 @@ +import torch +import torch_npu +import triton +import triton.language as tl + +# load with mask, store with scalar +@triton.jit +def 
sum_kernel_1(inp, mid, M, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + # 0 / 1 / 2 * 4 + (0,1,2,3) + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + inp_ptrs = inp + offset + mask = offset < M + inp_val = tl.load(inp_ptrs, mask=mask).to(tl.float32) + sum_val = tl.sum(inp_val) + mid_ptr = mid + pid + tl.store(mid_ptr, sum_val) + +def test_case(): + inp = torch.ones(16, device="npu", dtype=torch.float32) + mid = torch.empty(4, device="npu", dtype=torch.float32) + sum_kernel_1[(4, 1, 1)](inp, mid, 16, 4) + ref = torch.tensor([4.0, 4.0, 4.0, 4.0], device="npu", dtype=torch.float32) + assert torch.allclose(mid, ref, rtol=1e-03, atol=1e-03, equal_nan=True) diff --git a/third_party/ascend/examples/pytest_ut/test_strides.py b/third_party/ascend/examples/pytest_ut/test_strides.py new file mode 100644 index 000000000..40a2083df --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_strides.py @@ -0,0 +1,122 @@ +import logging +import pytest +import torch +import triton +import triton.language as tl + + +@triton.jit +def mv_kernel( + A, + B, + C, + N, + M, + stride_an, + stride_am, + stride_bm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None] + offset_m = tl.arange(0, BLOCK_M)[None, :] + n_mask = offset_n < N + A_ptrs = A + offset_n * stride_an + offset_m * stride_am + B_ptrs = B + offset_m * stride_bm + acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32) + for m in range(0, M, BLOCK_M): + m_mask = m + offset_m < M + a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32) + b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32) + acc += a * b + A_ptrs += BLOCK_M * stride_am + B_ptrs += BLOCK_M * stride_bm + + acc = tl.sum(acc, axis=1) + C_ptrs = C + offset_n * stride_cn + tl.store(C_ptrs, acc[:, None], mask=n_mask) + + +def mv(inp, vec): + assert inp.shape[1] == vec.shape[0], "incompatible dimensions" + N, M = inp.shape + out = torch.empty((N,), device=inp.device, dtype=inp.dtype) + + def grid(META): + return (triton.cdiv(N, META["BLOCK_N"]),) + mv_kernel[grid]( + inp, + vec, + out, + N, + M, + inp.stride(0), + inp.stride(1), + vec.stride(0), + out.stride(0), + 32, + 4, + ) + return out + + +def to_reference(inp, upcast=False): + if inp is None: + return None + ref_inp = inp + ref_inp = ref_inp.to("cpu") + if upcast: + if ref_inp.is_complex(): + ref_inp = ref_inp.to(torch.complex128) + else: + ref_inp = ref_inp.to(torch.float64) + return ref_inp + + +RESOLUTION = { + torch.bool: 0, + torch.int16: 0, + torch.int32: 0, + torch.int64: 0, + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, + torch.float64: 1e-7, + torch.complex32: 1e-3, + torch.complex64: 1.3e-6, +} + + +def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): + assert res.dtype == dtype + ref = ref.to(dtype) + atol = 1e-4 * reduce_dim + rtol = RESOLUTION[dtype] + torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan) + + +def gems_assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1): + res = res.to("cpu") + assert ref.device == torch.device("cpu") + assert_close( + res, ref, dtype, equal_nan=equal_nan, reduce_dim=reduce_dim + ) + + +@pytest.mark.mv +@pytest.mark.parametrize("M, N", [(1, 32), (160, 1024), (5333, 497)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_mv(M, N, dtype): + matrix = torch.randn((N, M), dtype=dtype, device="npu") + vector = torch.randn((M,), 
dtype=dtype, device="npu") + ref_matrix = to_reference(matrix, True) + ref_vector = to_reference(vector, True) + + ref_out = torch.mv(ref_matrix, ref_vector) + + res_out = mv(matrix, vector) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=M) + diff --git a/third_party/ascend/examples/pytest_ut/test_sub.py b/third_party/ascend/examples/pytest_ut/test_sub.py new file mode 100644 index 000000000..b539e76c3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sub.py @@ -0,0 +1,43 @@ +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + +def torch_pointwise(x0, x1): + res = x0 - x1 + return res + + +@triton.jit +def triton_sub(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 - tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int32', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu() + triton_sub[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_sum.py b/third_party/ascend/examples/pytest_ut/test_sum.py new file mode 100644 index 000000000..78c4975f3 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sum.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
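+# Sums a + b + c over dim 0 (sum_loop_high) and over dim 1 (sum_loop_low) using tiled
+# accumulation loops, and checks both results against torch's native reduction.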
+ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common +import time + + +@triton.jit +def sum_loop_high(in_ptr0, in_ptr1, in_ptr2, out_ptr0, rnumel, xnumel, + XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, RBLOCK: tl.constexpr): + R = rnumel + X = xnumel + xoffset = tl.program_id(0) * XBLOCK + xbase = tl.arange(0, XBLOCK_SUB) + rbase = tl.arange(0, RBLOCK) + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + xbase + x0 = xindex[None, :] + _tmp6 = tl.full([RBLOCK, XBLOCK_SUB], 0, tl.float32) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = None + r1 = rindex[:, None] + tmp0 = tl.load(in_ptr0 + (X * r1 + (x0)), rmask) + tmp1 = tl.load(in_ptr1 + (X * r1 + (x0)), rmask) + tmp3 = tl.load(in_ptr2 + (X * r1 + (x0)), rmask) + tmp2 = tmp0 + tmp1 + tmp4 = tmp2 + tmp3 + _tmp6 = _tmp6 + tmp4 + tmp6 = tl.sum(_tmp6, 0) + tl.store(out_ptr0 + (xindex), tmp6, None) + + +@triton.jit +def sum_loop_low(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, ynumel, + XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + X = xnumel + Y = ynumel + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK) + + x0 = xindex[:, None] + rbase = tl.arange(0, RBLOCK) + _tmp6 = tl.full([XBLOCK, RBLOCK], 0, tl.float32) + for roffset in range(0, ynumel, RBLOCK): + rindex = roffset + rbase + rmask = None + r1 = rindex[None, :] + tmp0 = tl.load(in_ptr0 + (r1 + (Y * x0)), rmask) + tmp1 = tl.load(in_ptr1 + (r1 + (Y * x0)), rmask) + tmp3 = tl.load(in_ptr2 + (r1 + (Y * x0)), rmask) + tmp2 = tmp0 + tmp1 + tmp4 = tmp2 + tmp3 + _tmp6 = _tmp6 + tmp4 + tmp6 = tl.sum(_tmp6, 1) + + tl.store(out_ptr0 + (xindex), tmp6, None) + + +def foo(a, b, c): + y = a + b + c + y = y.sum(0) + return y + + +def bar(a, b, c): + y = a + b + c + y = y.sum(1) + return y + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (64, 8192), 1, 2, 256, 16], + ] + ) +def test_case_1(param_list): + dtype, shape, ncore, XB, YB, ZB = param_list + a = test_common.generate_tensor(shape, dtype).npu() + b = test_common.generate_tensor(shape, dtype).npu() + c = test_common.generate_tensor(shape, dtype).npu() + value = torch.empty_strided((a.shape[0],), (1,)).npu() + + std_low_ret = bar(a, b, c) + print(f"std_low_ret = {std_low_ret[0:8]}") + XBLOCK = 64 + RBLOCK = 32 + NBLOCKS = a.shape[0] // XBLOCK + sum_loop_low[NBLOCKS, 1, 1](a, b, c, value, a.shape[0], a.shape[1], XBLOCK, RBLOCK) + triton_low_ret = value + print(f"triton_low_ret = {triton_low_ret[0:8]}") + torch.testing.assert_close(std_low_ret, triton_low_ret, rtol=1e-3, atol=1e-3) + + std_ret2 = foo(a, b, c) + print(f"std_ret2 = {std_ret2[0:8]}") + NBLOCKS = 32 + XBLOCK = a.shape[1] // NBLOCKS + XBLOCK_SUB = min(64, max(XBLOCK // 2, 32)) + RBLOCK = 64 + + value2 = torch.empty_strided((a.shape[1],), (1,)).npu() + sum_loop_high[NBLOCKS, 1, 1](a, b, c, value2, a.shape[0], a.shape[1], XBLOCK, XBLOCK_SUB, RBLOCK) + triton_ret2 = value2 + print(f"triton_ret2 = {triton_ret2[0:8]}") + torch.testing.assert_close(std_ret2, triton_ret2, rtol=1e-3, atol=1e-3) diff --git a/third_party/ascend/examples/pytest_ut/test_sum_dim0.py b/third_party/ascend/examples/pytest_ut/test_sum_dim0.py new file mode 100644 index 000000000..caa5f7fda --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_sum_dim0.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
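+# Reduces along dim 0 with tl.sum; MNUMEL/NNUMEL pad the block shape to a supported
+# size while masks restrict loads to the real M x N extent.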
+import pytest + +import triton +import triton.language as tl +import time + +import torch +import torch_npu +import test_common + +def standard_sum(x0,dim,dtype): + res = torch.sum(x0, dim,dtype = dtype) + return res + +@triton.jit +def triton_sum_dim0(in_ptr0, out_ptr0, M : tl.constexpr, N : tl.constexpr, MNUMEL: tl.constexpr, NNUMEL: tl.constexpr): + mblk_idx = tl.arange(0,MNUMEL) + nblk_idx = tl.arange(0,NNUMEL) + + mmask = mblk_idx [1, 1] + acc_11 = tl.dot(a_vals, b_vals) # [1, 1] + tl.sync_block_wait("cube", "vector", 5) + + # Pointer grid for the single output element: shape [1,1] + c_ptrs = C_ptr + (row + offs_i) * N + (col + offs_j) + + # Store exp(acc) without scalar indexing + tl.store(c_ptrs, tl.exp(acc_11)) + + +@pytest.mark.parametrize( + 'param_list', + [ + # dtype, A-shape, B-shape + ['float32', (4, 4), (4, 4)], + ['float32', (2, 3), (3, 5)], + ] +) +def test_matmul_exp(param_list): + dtype, ashape, bshape = param_list + M, K = ashape + K2, N = bshape + assert K == K2, "Inner dimensions must match" + + # generate input tensors + A = test_common.generate_tensor(ashape, dtype).npu() + B = test_common.generate_tensor(bshape, dtype).npu() + C = test_common.generate_tensor((M, N), dtype).npu() + + # run kernel + grid = (M, N) # one program per output element + triton_matmul_exp[grid](A, B, C, M, N, K) + + # reference result + C_ref = (A @ B).exp() + + # compare + test_common.validate_cmp(dtype, C, C_ref) + +if __name__ == "__main__": + test_matmul_exp('float32', (4, 4), (4, 4)) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_tan.py b/third_party/ascend/examples/pytest_ut/test_tan.py new file mode 100644 index 000000000..ac881a6ef --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_tan.py @@ -0,0 +1,80 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +import triton.language.extra.ascend.libdevice as libdevice + +def standard_unary(x0, dtype): + res = torch.tan(x0) + return res + + +def standard_binary(x0, y0, dtype): + res = x0 + y0 + return res + + +@triton.jit +def triton_elementwise_unary(in_ptr0, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + ret = libdevice.tan(x) + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +@triton.jit +def triton_elementwise_binary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: tl.constexpr): + idx_block = tl.arange(0, NUMEL) + x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) + y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) + ret = x + y + tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) + + +types = [ + (torch.float32, 'float32'), + (torch.float16, 'float16'), + # (torch.bfloat16, 'bfloat16'), + # (torch.int8, 'int8'), + # (torch.int16, 'int16'), + # (torch.int32, 'int32'), + # (torch.int64, 'int64'), +] + +shapes = [ + (3, 32), + (-32, 32), + (37, 64), + (-256, 256), + (781, 1024), +] + +map_for_64_t = {37: 31} + +@pytest.mark.skip(reason="randomly failed accuracy test") +@pytest.mark.parametrize('dtype,sigtype', types) +@pytest.mark.parametrize('N,NUMEL', shapes) +def test_elementwsie_common(dtype, sigtype, N, NUMEL): + N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N + + if sigtype == "int64": + N = map_for_64_t[N] if N in map_for_64_t else N + + print(f"elementwise : ({N},) {dtype} {sigtype}") + + x0 = test_common.generate_tensor(shape=(N,), dtype=sigtype) + + ans = 
standard_unary(x0, dtype) + x0 = x0.npu() + print(ans) + + out = torch.zeros((N,), dtype=dtype).npu() + triton_elementwise_unary[1, 1, 1](x0, out, N=N, NUMEL=NUMEL, debug=True) + print(out) + + test_common.validate_cmp(sigtype, out, ans) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_template.py b/third_party/ascend/examples/pytest_ut/test_template.py new file mode 100644 index 000000000..b9409639e --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_template.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import pytest + +import triton +import triton.language as tl + +import time +import torch +import torch_npu +import test_common + +NBLOCKS = 1 +X_SIZE : tl.constexpr = 4 +Y_SIZE : tl.constexpr = 64 +Z_SIZE : tl.constexpr = 32 +NUMEL = X_SIZE * Y_SIZE * Z_SIZE + +def fn(input): + output = input.reshape((X_SIZE, Y_SIZE, Z_SIZE)).permute((1, 0, 2)).reshape((X_SIZE * Y_SIZE * Z_SIZE)) + return output + +@triton.jit +def fn_kernel(output_ptr, input_ptr): + col_offsets = tl.arange(0, X_SIZE * Y_SIZE * Z_SIZE) + input_local = tl.load(input_ptr + col_offsets) + input_local = input_local.reshape((X_SIZE, Y_SIZE, Z_SIZE)).permute((1, 0, 2)).reshape((X_SIZE * Y_SIZE * Z_SIZE)) + tl.store(output_ptr + col_offsets, input_local) + + +def test_cases(): + input = torch.randn(NUMEL, dtype=torch.float16).npu() + output = torch.randn(NUMEL, dtype=torch.float16).npu() + output2 = torch.randn(NUMEL, dtype=torch.float16).npu() + fn_kernel[1,1,1](output, input) + output2 = fn(input) + test_common.validate_cmp('float16', output, output2) + print("data validation passed") diff --git a/third_party/ascend/examples/pytest_ut/test_tensor_descriptor.py b/third_party/ascend/examples/pytest_ut/test_tensor_descriptor.py new file mode 100644 index 000000000..592e25d4f --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_tensor_descriptor.py @@ -0,0 +1,165 @@ +import math +import pytest +import torch +import triton +import triton.language as tl +import test_common + + +@pytest.mark.parametrize("dtype", ['float32', 'float16', 'bfloat16', 'int32', 'int64', 'int16', 'int8']) +@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16)]) +def test_tensor_descriptor_load_store(dtype, M_BLOCK, N_BLOCK): + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + in_desc = tl.make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + block = in_desc.load([moffset, noffset]) + out_desc.store([moffset, noffset], block) + + M, N = M_BLOCK * 2, N_BLOCK * 2 + inp = test_common.generate_tensor((M, N), dtype).npu() + out = inp.new_empty((M, N)) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + kernel[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + torch.testing.assert_close(inp, out) + + +@pytest.mark.parametrize("dtype", ['float32', 'float16', 'bfloat16', 'int32', 'int64', 'int16', 'int8']) +def test_tensor_descriptor_load_store3d(dtype): + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, K, stride_m, stride_n, stride_k, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, + K_BLOCK: tl.constexpr): + in_desc = tl.make_tensor_descriptor( + a_ptr, + shape=[M, N, K], + strides=[stride_m, stride_n, stride_k], + 
block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N, K], + strides=[stride_m, stride_n, stride_k], + block_shape=[M_BLOCK, N_BLOCK, K_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + koffset = tl.program_id(2) * K_BLOCK + block = in_desc.load([moffset, noffset, koffset]) + out_desc.store([moffset, noffset, koffset], block) + + M, N, K = 8, 16, 32 + inp = test_common.generate_tensor((M, N, K), dtype).npu() + out = inp.new_empty((M, N, K)) + + M_BLOCK = 2 + N_BLOCK = 4 + + # 自动调整 K_BLOCK,保证最后一维 block 至少 16 字节 + dtype = getattr(inp, "dtype", None) + itemsize = torch.tensor([], dtype=inp.dtype).element_size() + min_k_block = max(16 // itemsize, 1) + K_BLOCK = max(8, min_k_block) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + grid_k = K // K_BLOCK + + kernel[(grid_m, grid_n, grid_k)](out, inp, *inp.shape, *out.stride(), M_BLOCK, N_BLOCK, K_BLOCK) + torch.testing.assert_close(inp.reshape(M * N * K), out.reshape(M * N * K)) + + +# Exercise the functional load/store builtins once to ensure they map through. +@pytest.mark.parametrize("dtype", ["float32"]) +def test_tensor_descriptor_functional_interface(dtype): + """Copies an entire tensor blockwise using the descriptor builtins.""" + + @triton.jit + def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + in_desc = tl.make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + block = tl.load_tensor_descriptor(in_desc, [moffset, noffset]) + tl.store_tensor_descriptor(out_desc, [moffset, noffset], block) + + M, N = 32, 128 + inp = test_common.generate_tensor((M, N), dtype).npu() + + M_BLOCK = 8 + N_BLOCK = 32 + out = inp.new_empty((M, N)) + + grid_m = M // M_BLOCK + grid_n = N // N_BLOCK + + kernel[(grid_m, grid_n)](out, inp, M, N, M_BLOCK, N_BLOCK) + torch.testing.assert_close(inp, out) + + +@pytest.mark.parametrize("dtype_str", ["int32"]) +@pytest.mark.parametrize("shape", [(128, 2, 4), (64, 2, 4), (32, 2, 4), (2, 4, 32), (2, 4, 2)]) +@pytest.mark.parametrize("axis", [0, 1, 2]) +@pytest.mark.parametrize("device", ["npu"]) +def test_reduce_max(dtype_str, shape, axis, device): + + @triton.jit + def kernel( + In, + Out, + in_shape1: tl.constexpr, + in_shape2: tl.constexpr, + in_shape3: tl.constexpr, + ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, + axis: tl.constexpr, + ): + in_desc = tl.make_tensor_descriptor( + base=In, + shape=[in_shape1 * in_shape2 * in_shape3], + strides=[1], + block_shape=[in_shape1 * in_shape2 * in_shape3], + ) + out_desc = tl.make_tensor_descriptor( + base=Out, + shape=[ou_shape1 * ou_shape2], + strides=[1], + block_shape=[ou_shape1 * ou_shape2], + ) + val = in_desc.load([0]).reshape(in_shape1, in_shape2, in_shape3) + output = tl.max(val, axis=axis) + out_desc.store([0], output.reshape(out_desc.block_shape)) + + inp = torch.arange(math.prod(shape), + dtype=getattr(torch, dtype_str), + device=device).reshape(shape) + expected, indices = torch.max(inp.to(torch.int64), dim=axis) + expected = expected.to(torch.int32) + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](inp, actual, *shape, *expected.shape, axis=axis) + assert torch.equal(expected, actual) \ No newline at end of file diff --git 
a/third_party/ascend/examples/pytest_ut/test_top2gating_argmax.py b/third_party/ascend/examples/pytest_ut/test_top2gating_argmax.py new file mode 100644 index 000000000..a5b33a778 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_top2gating_argmax.py @@ -0,0 +1,77 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +@triton.jit +def triton_4(in_ptr2, in_ptr4, out_ptr10, x0_numel, r1_numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + RBLOCK: tl.constexpr = 4 + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + base2 = tl.arange(0, RBLOCK) + loops2: tl.constexpr = r1_numel // RBLOCK + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1[:, None] + r1 = base2[None, :] + tmp15 = tl.load(in_ptr2 + (r1 + (8 * x0)), None) + tmp25 = tl.load(in_ptr4 + (r1 + (4 * x0)), None) + tmp21 = tmp15.to(tl.float32) + tmp26 = tmp25 * tmp21 + tmp27 = tl.reshape(tmp26, [XBLOCK_SUB, RBLOCK]) + tmp31 = tl.broadcast_to(r1.reshape(1, RBLOCK), tmp27.shape) + _, tmp30_tmp = max_with_index(tmp27, tmp31, 1) + tmp30 = tmp30_tmp.reshape(XBLOCK_SUB, 1) + tmp43 = tmp30.to(tl.int32) + tmp46 = tmp43 + tl.store(out_ptr10 + (x0), tmp46, None) + + +def test_max_with_index_dim0(): + mask = torch.randint(low=0, high=2, size=(512, 2, 4), dtype=torch.int32).npu() + weights = torch.randn((512, 4), device='npu', dtype=torch.float32) + buf32 = torch.randint(low=0, high=8, size=(512,), dtype=torch.int32).npu() + XBLOCK = 32 + XBLOCK_SUB = 32 + triton_4[16, 1, 1](mask, weights, buf32, 512, 4, XBLOCK, XBLOCK_SUB) + + _, first_idx = torch.max(weights * mask[:, 0, :], dim=1) + print(f"first_idx: {first_idx[0:8]}") + print(f"triton_idx: {buf32[0:8]}") + + assert torch.all(torch.eq(first_idx, buf32.to(first_idx.dtype))) diff --git a/third_party/ascend/examples/pytest_ut/test_trans_3d.py b/third_party/ascend/examples/pytest_ut/test_trans_3d.py new file mode 100644 index 000000000..a412f5e4b --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_trans_3d.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
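+# Exercises tl.trans on a 3D block for the (1, 0, 2) and (0, 2, 1) permutations,
+# validated element-wise against torch.permute.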
+ +import logging +import math + +import pytest +import test_common +import torch + +import triton +import triton.language as tl + + +@triton.jit +def fn_npu_102(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, 1, 0, 2) + + oidx = ( + zidx[:, None, None] * YB * KB + yidx[None, :, None] * KB + kidx[None, None, :] + ) + + tl.store(output_ptr + oidx, ret) + + +@triton.jit +def fn_npu_021(output_ptr, x_ptr, YB: tl.constexpr, ZB: tl.constexpr, KB: tl.constexpr): + yidx = tl.arange(0, YB) + zidx = tl.arange(0, ZB) + kidx = tl.arange(0, KB) + idx = yidx[:, None, None] * ZB * KB + zidx[None, :, None] * KB + kidx[None, None, :] + + X = tl.load(x_ptr + idx) + + ret = tl.trans(X, (0, 2, 1)) + + oidx = ( + yidx[:, None, None] * ZB * KB + kidx[None, :, None] * ZB + zidx[None, None, :] + ) + + tl.store(output_ptr + oidx, ret) + + +@pytest.mark.parametrize("shape", [(23, 5, 31)]) +@pytest.mark.parametrize("dtype", ["float32"]) +def test_permute_3d(shape, dtype): + logging.debug(f"dtype:{dtype} shape:{shape}") + + data_type = eval("torch." + dtype) + x = torch.randint(low=0, high=2, size=shape, dtype=data_type).npu() + + triton_res = torch.empty((shape[1], shape[0], shape[2]), dtype=data_type).npu() + torch_res = torch.permute(x, (1, 0, 2)) + fn_npu_102[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + test_common.validate_cmp(dtype, triton_res, torch_res) + + triton_res = torch.empty((shape[0], shape[2], shape[1]), dtype=data_type).npu() + torch_res = torch.permute(x, (0, 2, 1)) + fn_npu_021[1, 1, 1](triton_res, x, shape[0], shape[1], shape[2]) + test_common.validate_cmp(dtype, triton_res, torch_res) diff --git a/third_party/ascend/examples/pytest_ut/test_triton_eq.py b/third_party/ascend/examples/pytest_ut/test_triton_eq.py new file mode 100644 index 000000000..76d31ad71 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_triton_eq.py @@ -0,0 +1,43 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +def torch_pointwise(x0, x1): + res = x0 == x1 + return res + + +@triton.jit +def triton_test(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 == tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' 
+ 'bool')).npu() + triton_test[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_triton_le.py b/third_party/ascend/examples/pytest_ut/test_triton_le.py new file mode 100644 index 000000000..30d1979dd --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_triton_le.py @@ -0,0 +1,43 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +def torch_pointwise(x0, x1): + res = x0 <= x1 + return res + + +@triton.jit +def triton_le(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 <= tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' + 'bool')).npu() + triton_le[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_triton_lt.py b/third_party/ascend/examples/pytest_ut/test_triton_lt.py new file mode 100644 index 000000000..22fbde4c4 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_triton_lt.py @@ -0,0 +1,43 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +def torch_pointwise(x0, x1): + res = x0 < x1 + return res + + +@triton.jit +def triton_lt(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 < tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' 
+ 'bool')).npu() + triton_lt[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_triton_neq.py b/third_party/ascend/examples/pytest_ut/test_triton_neq.py new file mode 100644 index 000000000..c11bb6e30 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_triton_neq.py @@ -0,0 +1,42 @@ +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + +def torch_pointwise(x0, x1): + res = x0 != x1 + return res + + +@triton.jit +def triton_neq(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + for loop1 in range(loops1): + x0_prime = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = tl.load(in_ptr1 + (x0), None) + tmp2 = tmp0 != tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (2, 4096, 8), 2, 32768, 1024], + ['float16', (2, 4096, 8), 2, 32768, 1024], + ['int8', (2, 4096, 8), 2, 32768, 1024], + ] + ) + +def test_case(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + x1 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_pointwise(x0, x1) + y_cal = torch.zeros(shape, dtype = eval('torch.' + 'bool')).npu() + triton_neq[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_umulhi.py b/third_party/ascend/examples/pytest_ut/test_umulhi.py new file mode 100644 index 000000000..4dcbcc9f4 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_umulhi.py @@ -0,0 +1,41 @@ +import torch +import torch_npu +import numpy as np +import triton +import triton.language as tl +from numpy.random import RandomState + +# inp the two 32 bit signed integers. 
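+# tl.umulhi returns the upper 32 bits of the 64-bit product of two 32-bit inputs;
+# umulhi32 below reproduces the same computation with NumPy int64 arithmetic as the reference.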
+@triton.jit +def umulhi_kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + +# accuracy reference +def umulhi32(a, b): + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + product_64 = a_64 * b_64 + # get the high part + result_high_32 = product_64 >> 32 + return result_high_32.astype(np.int32) + + +def test_umulhi(): + N = 128 + x = torch.randint(low=0, high=2000, size=(N,), dtype=torch.int32) + y = torch.randint(low=0, high=2000, size=(N,), dtype=torch.int32) + xx = x.npu() + yy = y.npu() + z_tri = torch.zeros(size=(N,), dtype=torch.int32).npu() + umulhi_kernel[(1,)](xx, yy, z_tri, N=N) + + xxx = x.numpy() + yyy = y.numpy() + z_ref = umulhi32(xxx, yyy) + z_ref1 = torch.from_numpy(z_ref).npu() + torch.equal(z_tri, z_ref1) \ No newline at end of file diff --git a/third_party/ascend/examples/pytest_ut/test_unlign_max_with_index.py b/third_party/ascend/examples/pytest_ut/test_unlign_max_with_index.py new file mode 100644 index 000000000..76f82a9fc --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_unlign_max_with_index.py @@ -0,0 +1,97 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import test_common + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +def torch_max_pr_index(x0): + res = torch.max(x0, 0) + return res + + +# x0_numel 128, r1_numel 65, 128, 128, 32 +@triton.jit +def triton_kernel(in_ptr0, in_ptr1, out_ptr0, out_ptr1, x0_numel, r1_numel, XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, RBLOCK: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + base2 = tl.arange(0, RBLOCK) + loops2: tl.constexpr = (r1_numel + RBLOCK - 1) // RBLOCK + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1[:, None] + xmask = x0 < x0_numel + for loop2 in range(loops2): + r1 = loop2 * RBLOCK + base2[None, :] + rmask = r1 < r1_numel + tmp0 = tl.load(in_ptr0 + (r1 + (65 * x0)), rmask & xmask, other=0) + tmp2 = tl.load(in_ptr1 + (r1 + (65 * x0)), rmask & xmask, other=0) + tmp11 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) + tmp22 = tl.reshape(tmp2, [XBLOCK_SUB, RBLOCK]) + tmp4, tmp5 = max_with_index(tmp11, tmp22, 0) + tl.store(out_ptr0 + r1, tmp4[None, :], None) + tl.store(out_ptr1 + r1, tmp5[None, :], None) + + +@pytest.mark.parametrize('param_list', + [ + ['float32', (128, 65), 1, 128, 65, 32], + ] + ) +def test_max_with_index_dim0(param_list): + dtype, shape, ncore, XB, YB, RBLOCK = param_list + import math + numel = math.prod(shape) + xblock = numel // shape[-1] // ncore + assert (ncore * xblock * shape[-1] == numel) + + b = test_common.generate_tensor(shape, 
dtype).npu() + idx = torch.zeros((shape), device='npu', dtype=torch.int32) + c, d = torch.max(b, dim=0) + + print(f"std_ret = {c[0:4]}") + + ret = torch.empty((YB), device='npu', dtype=torch.float32) + ret1 = torch.zeros((YB), device='npu', dtype=torch.int32) + triton_kernel[ncore, 1, 1](b, idx, ret, ret1, shape[0], shape[1], shape[0], shape[0], RBLOCK) + print(f"triton_ret = {ret[0:4]}") + d = d.to(torch.int32) + ret1 = ret1.to(torch.int32) + print(f"d = {d[0:4]}") + print(f"triton_ret1 = {ret1[0:4]}") + assert torch.allclose(c, ret, rtol=1e-03, atol=1e-03, equal_nan=True) + assert torch.allclose(d, ret1, rtol=1e-03, atol=1e-03, equal_nan=True) diff --git a/third_party/ascend/examples/pytest_ut/test_unlign_sum.py b/third_party/ascend/examples/pytest_ut/test_unlign_sum.py new file mode 100644 index 000000000..3cf7a29ef --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_unlign_sum.py @@ -0,0 +1,42 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import triton +import triton.language as tl + +# [128,65] -> [128,128] +# [128,128] -> [128,65] mask 作用 +# [128,65] -> [128,64] [128,1] mask 作用 + +@triton.jit +def triton_unlign(in_ptr0, out_ptr0, x0_numel, r1_numel, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr, RBLOCK : tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB + base2 = tl.arange(0, RBLOCK) + loops2: tl.constexpr = (r1_numel + RBLOCK - 1) // RBLOCK + for loop1 in range(loops1): + x = offset + (loop1 * XBLOCK_SUB) + base1 + x0 = offset + (loop1 * XBLOCK_SUB) + base1[:, None] + xmask = x0 < x0_numel + _tmp2 = tl.full([XBLOCK_SUB, RBLOCK], 0, tl.float32) + for loop2 in range(loops2): + r1_prime = loop2 * RBLOCK + base2[:, None] + r1 = loop2 * RBLOCK + base2[None, :] + rmask = r1 < r1_numel + tmp0 = tl.load(in_ptr0 + (r1 + (65*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.reshape(tmp0, [XBLOCK_SUB, RBLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tmp3 + tmp2 = tl.sum(_tmp2, 1).reshape(XBLOCK_SUB,1) + tl.store(out_ptr0 + (x0), tmp2, xmask) + +def test_cases(): + size = (128, 65) + b = weights = torch.randn((size), dtype=torch.float32).npu() + c = torch.sum(b, dim=1) + ret = torch.randn((size[0]), device='npu', dtype=torch.float32).npu().reshape(size[0]) + triton_unlign[1, 1, 1](b, ret, size[0], size[1], size[0], size[0], 32) + assert torch.allclose(c, ret, rtol=1e-03, atol=1e-03, equal_nan=True) + diff --git a/third_party/ascend/examples/pytest_ut/test_unused_func_arg.py b/third_party/ascend/examples/pytest_ut/test_unused_func_arg.py new file mode 100644 index 000000000..90f02ff91 --- /dev/null +++ b/third_party/ascend/examples/pytest_ut/test_unused_func_arg.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
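+# Checks that a kernel whose signature carries constexpr arguments that are never
+# read in the body still compiles and runs correctly.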
+import triton +import triton.language as tl +import test_common +import torch +import torch_npu +import pytest +import math + +def expand_to_next_power_of_two(a): + if a <= 0: + raise ValueError("must >0") + if (math.log2(a)).is_integer(): + return a + return 2 ** math.ceil(math.log2(a)) + +@triton.jit +def triton_unused_func_arg_kernel( + output_ptr, x_ptr, + X : tl.constexpr, Y : tl.constexpr, Z : tl.constexpr, + XNUMEL : tl.constexpr, YNUMEL : tl.constexpr, ZNUMEL : tl.constexpr): + xidx = tl.arange(0,XNUMEL) + yidx = tl.arange(0,YNUMEL) + zidx = tl.arange(0,ZNUMEL) + Xmask = xidx/dev/null | wc -l) + [ $NPU_DEVICES -eq 0 ] && { + echo "No Ascend devices found!" + exit 1 + } + + echo "Detected $NPU_DEVICES Ascend devices" + + if [ -d ${WORKSPACE}triton ];then + rm -rf ${WORKSPACE}triton + fi + + if [ -d ~/.triton/dump ];then + rm -rf ~/.triton/dump + fi + + if [ -d ~/.triton/cache ];then + rm -rf ~/.triton/cache + fi + + test_dir=$1 + cd ${test_dir} + + # 清理旧日志 + rm -rf logs && mkdir logs + + # 记录测试开始时间 + start_time=$(date +"%Y-%m-%d %H:%M:%S") + echo "===== 测试开始时间: ${start_time} =====" + + # 运行测试并捕获退出状态 + set +e + pytest ${test_dir} -n auto --dist=loadfile -v --junitxml=logs/results.xml | tee logs/raw_output.log + pytest_exit=$? + set -e + + # 处理日志(添加设备标签) + awk ' + />> Worker gw[0-9]+ using NPU device/ { + split($0, parts, / /) + dev_id = parts[NF] + worker = parts[3] + print "[" strftime("%Y-%m-%d %H:%M:%S") "| DEV-" dev_id "] " $0 + next + } + { print "[" strftime("%Y-%m-%d %H:%M:%S") "| DEV-" dev_id "] " $0 } + ' logs/raw_output.log > logs/combined.log + + # 新增:解析测试结果统计 + total_tests=0 + passed_tests=0 + failed_tests=0 + skipped_tests=0 + error_tests=0 + + # 使用Python解析JUnit XML报告 + python3 -c " +import xml.etree.ElementTree as ET +import os + +xml_file = os.path.join('logs', 'results.xml') +if not os.path.exists(xml_file): + print('JUnitXML report not found:', xml_file) + exit(1) + +tree = ET.parse(xml_file) +root = tree.getroot() + +total = 0 +passed = 0 +failed = 0 +skipped = 0 +errors = 0 + +# 遍历所有testsuite +for testsuite in root.findall('testsuite'): + total += int(testsuite.get('tests', 0)) + passed += int(testsuite.get('tests', 0)) - int(testsuite.get('errors', 0)) - int(testsuite.get('failures', 0)) - int(testsuite.get('skipped', 0)) + failed += int(testsuite.get('failures', 0)) + skipped += int(testsuite.get('skipped', 0)) + errors += int(testsuite.get('errors', 0)) + +print(f'total_tests={total}') +print(f'passed_tests={passed}') +print(f'failed_tests={failed}') +print(f'skipped_tests={skipped}') +print(f'error_tests={errors}') +" > logs/stats.tmp + + # 加载统计结果 + source logs/stats.tmp + rm logs/stats.tmp + + # 记录测试结束时间 + end_time=$(date +"%Y-%m-%d %H:%M:%S") + duration=$(( $(date -d "$end_time" +%s) - $(date -d "$start_time" +%s) )) + duration_str=$(printf "%02dh %02dm %02ds" $((duration/3600)) $(((duration%3600)/60)) $((duration%60))) + + # 新增:生成统计摘要 + stats_summary=" +===== generalization_cases测试统计摘要 ===== +测试目录: $(basename ${test_dir}) +测试开始时间: ${start_time} +测试结束时间: ${end_time} +总耗时: ${duration_str} +------------------------ +总用例数: ${total_tests} +成功用例: ${passed_tests} +失败用例: ${failed_tests} +跳过用例: ${skipped_tests} +错误用例: ${error_tests} +成功率: $(( passed_tests * 100 / total_tests ))% (成功/总数) +设备数量: ${NPU_DEVICES} +======================== +" + + # 输出统计信息到控制台 + echo "${stats_summary}" + + # 追加统计信息到summary.txt + echo "${stats_summary}" >> ${SUMMARY_FILE} + + echo "========================================" + echo "All tests completed!" 
+ echo "JUnit Report: logs/results.xml" + echo "Combined Log: logs/combined.log" + echo "统计摘要已追加到: ${SUMMARY_FILE}" + echo "========================================" + + zip_file=$2 + cd ${test_dir}/logs + zip ${zip_file} combined.log + cp ${zip_file} "/home/daily_log" + + # 返回pytest的退出状态 + return $pytest_exit +} + +# build in torch 2.6.0 +source /opt/miniconda3/bin/activate torch_260 +build_triton + +cd ${WORKSPACE} + +# 初始化统计文件 +echo "生成时间: $(date +"%Y-%m-%d %H:%M:%S")" >> ${SUMMARY_FILE} +echo "========================================" >> ${SUMMARY_FILE} + +# run inductor cases +TEST_inductor_cases="${WORKSPACE}/ascend/examples/inductor_cases" +cd ${TEST_inductor_cases} +bash run_inductor_test.sh + +# run gene case +zip_file="test_generalizetion_case_$(date +%Y%m%d).zip" +TEST_generalization="${WORKSPACE}/ascend/examples/generalization_cases" +run_case_by_multi_card ${TEST_generalization} ${zip_file} + +echo "========================================" >> ${SUMMARY_FILE} + +# run flaggems cases +TEST_flaggems_cases="${WORKSPACE}/ascend/examples/flaggems_cases" +cd ${TEST_flaggems_cases} +bash run_flaggems_test.sh + +# copy summary.txt to /home/daily_log +cp ${SUMMARY_FILE} /home/daily_log \ No newline at end of file diff --git a/third_party/ascend/examples/run_test.sh b/third_party/ascend/examples/run_test.sh new file mode 100644 index 000000000..746103c0e --- /dev/null +++ b/third_party/ascend/examples/run_test.sh @@ -0,0 +1,192 @@ +#!/bin/bash + +set -ex + +script=$(readlink -f "$0") +script_dir=$(dirname "$script") + +# skiped script +skip_script=("bench_utils.py" "11-rab_time.py") + +function uninstall_triton_ascend() { + set +e + while true; do + pip3 uninstall triton_ascend -y | grep "Found existing installation" + if [ $? -eq 1 ]; then + echo "All triton_ascend versions are uninstalled" + break + fi + done + set -e +} + +function build_triton() { + + cd ${WORKSPACE} + # Run uninstall once because the while-loop does not stop. No idea why. + # uninstall_triton_ascend + pip3 uninstall triton_ascend -y + + git submodule set-url third_party/triton https://gitee.com/shijingchang/triton.git + git submodule sync && git submodule update --init --recursive + + bash scripts/build.sh ${WORKSPACE}/ascend ${LLVM_BUILD_DIR} 3.2.0 install 0 +} + +function run_pytestcases() { + if [ -d ${HOME}/.triton/dump ]; then + rm -rf ${HOME}/.triton/dump + fi + if [ -d ${HOME}/.triton/cache ]; then + rm -rf ${HOME}/.triton/cache + fi + + cd ${script_dir} + TARGET_DIR="$1" + cd ${TARGET_DIR} + pytest -n 16 --dist=load . || { exit 1 ; } + +} + +function run_pythoncases() { + if [ -d ${HOME}/.triton/dump ]; then + rm -rf ${HOME}/.triton/dump + fi + if [ -d ${HOME}/.triton/cache ]; then + rm -rf ${HOME}/.triton/cache + fi + + cd ${script_dir} + TARGET_DIR="$1" + cd ${TARGET_DIR} + + declare -a pids + declare -A status_map + has_failure=0 + + # 查找并运行所有.py文件 + for test_script in *.py; do + for skip_item in "${skip_script[@]}"; do + if [ "$test_script" == "$skip_item" ]; then + break + fi + done + + if [ -f "$test_script" ]; then + echo "启动测试: $test_script" + python "./$test_script" & + pid=$! + pids+=($pid) + status_map[$pid]=$test_script + fi + done + + # 等待所有后台进程完成并检查状态 + for pid in "${pids[@]}"; do + wait "$pid" + exit_status=$? + script_name=${status_map[$pid]} + + if [ $exit_status -ne 0 ]; then + echo "[失败] $script_name - 退出码 $exit_status" + has_failure=1 + else + echo "[成功] $script_name" + fi + done + + echo "--------------------------------" + + # 根据测试结果退出 + if [ $has_failure -eq 1 ]; then + echo "部分测试失败!" 
+ exit 1 + else + echo "所有测试通过!" + exit 0 + fi +} + +function validate_git_commit_title() { + if [ $# -lt 1 ]; then + echo "Usage: $0 " + exit 1 + fi + commit_title=$1 + if ! echo "${commit_title}" | grep -qE "^(feat|fix|docs|style|refactor|test|chore|revert)(\(.*\))?: .+"; then + echo "❌ The git commit title does not comply with the specifications!" + echo "Format Requirements: (): " + echo "e.g.: feat(user): The login function is added." + echo "Allowed Types: feat | fix | docs | style | refactor | test | chore | revert" + exit 1 + fi + echo "✅ The submitted information complies with the specifications." +} + +function validate_pr_all_commits_title() { + commit_titles=$(git log master..HEAD --oneline | sed 's/^[^ ]* //') + if [ -z "$commit_titles" ]; then + echo "No commits found between HEAD and master." + exit 1 + fi + echo "Validating commit titles..." + echo "----------------------------" + while IFS= read -r title; do + echo "Checking: $title" + if ! validate_git_commit_title "$title" 2>/dev/null; then + echo "Error in commit: $title" >&2 + HAS_ERROR=true + fi + done <<< "$commit_titles" + if [ "$HAS_ERROR" = true ]; then + echo "----------------------------" + echo "❌ Some commit titles do not meet the specifications." >&2 + exit 1 + else + echo "----------------------------" + echo "✅ All commit titles meet the specifications." + fi +} + +# if ! validate_pr_all_commits_title 2>/dev/null; then +# exit 1 +# fi + +source /usr/local/CANN_8.2.RC1.alpha002/ascend-toolkit/set_env.sh +export LLVM_BUILD_DIR=/opt/llvm-b5cc222 + +# FIXME: 20250508 the bishengir-compile in the CANN 8.0.T115 fails lots of cases +# So we need to use another version of compiler. +COMPILER_ROOT=/home/shared/bisheng_toolkit_20250922 +BSIR_COMPILE_PATH=$(find "$COMPILER_ROOT" -name "bishengir-compile" | xargs dirname) +export PATH=${COMPILER_ROOT}:${BSIR_COMPILE_PATH}:$PATH +# FIXME: the 20250812 bishengir-compile requires the pairing bisheng compiler +export BISHENG_INSTALL_PATH=/home/shared/cann_compiler_20250812/compiler/ccec_compiler/bin + +# build in torch 2.6.0 +source /opt/miniconda3/bin/activate torch_260 +build_triton + +echo "Run ttir to linalg tests..." +cd ${WORKSPACE}/build/cmake.linux-aarch64-cpython-3.11 +ninja check-triton-adapter-lit-tests +if [ $? -eq 0 ]; then + echo "All ttir to linalg tests passed" +else + echo "Some ttir to linalg tests failed" + exit 1 +fi + +pytestcase_dir=("pytest_ut") +for test_dir in "${pytestcase_dir[@]}"; do + echo "run pytestcase in ${test_dir}" + run_pytestcases ${test_dir} +done + +pythoncase_dir=("autotune_cases" "benchmark_cases" "tutorials") +for test_dir in "${pythoncase_dir[@]}"; do + echo "run pythoncase in ${test_dir}" + run_pythoncases ${test_dir} +done + + diff --git a/third_party/ascend/examples/tutorials/01-vector-add.py b/third_party/ascend/examples/tutorials/01-vector-add.py new file mode 100644 index 000000000..29e93bba4 --- /dev/null +++ b/third_party/ascend/examples/tutorials/01-vector-add.py @@ -0,0 +1,79 @@ +""" +Vector Addition +=============== + +In this tutorial, you will write a simple vector addition using Triton. + +In doing so, you will learn about: + +* The basic programming model of Triton. + +* The `triton.jit` decorator, which is used to define Triton kernels. + +* The best practices for validating and benchmarking your custom ops against native reference implementations. 
+ +""" + +# %% +# Compute Kernel +# -------------- + +import torch +import torch_npu + +import triton +import triton.language as tl + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# %% +# Let's also declare a helper function to (1) allocate the `z` tensor +# and (2) enqueue the above kernel with appropriate grid/block sizes: + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + + +# %% +# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device='npu') +y = torch.rand(size, device='npu') +output_torch = x + y +output_triton = add(x, y) +print(output_torch) +print(output_triton) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') diff --git a/third_party/ascend/examples/tutorials/02-fused-softmax.py b/third_party/ascend/examples/tutorials/02-fused-softmax.py new file mode 100644 index 000000000..c4300c730 --- /dev/null +++ b/third_party/ascend/examples/tutorials/02-fused-softmax.py @@ -0,0 +1,110 @@ +""" +Fused Softmax +============= +""" + +import torch +import torch_npu +import triton +import triton.language as tl +from triton.runtime import driver + + +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. 
+ """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +target = triton.runtime.driver.active.get_current_target() +kernels = {} + +def softmax(x, stream): + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + # Allocate output + y = torch.empty_like(x) + + # pre-compile kernel to get register usage and compute thread occupancy. + kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + num_programs = 32 + kernel = softmax_kernel + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)]( + y, + x, + x.stride(0), + y.stride(0), + n_rows, + n_cols, + BLOCK_SIZE + ) + return y + + +# %% +# Unit Test +# --------- + +# %% +# We make sure that we test our kernel on a matrix with an irregular number of rows and columns. +# This will allow us to verify that our padding mechanism works. 
+device = torch.npu.current_device() +stream = torch.npu.current_stream(device).npu_stream +torch.manual_seed(0) +x = torch.randn(1823, 781, device='npu') +y_triton = softmax(x, stream) +y_torch = torch.softmax(x, axis=1) +assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) +print(y_triton) +print(y_torch) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(y_triton-y_torch))}') \ No newline at end of file diff --git a/third_party/ascend/examples/tutorials/03-layer-norm.py b/third_party/ascend/examples/tutorials/03-layer-norm.py new file mode 100644 index 000000000..85516ea95 --- /dev/null +++ b/third_party/ascend/examples/tutorials/03-layer-norm.py @@ -0,0 +1,108 @@ +""" +Layer Normalization +============= +""" + +import pytest +import torch +import triton +import triton.language as tl +import torch_npu + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) 
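+        # The masked load above returns 0 for the tail lanes, but `x - mean` would turn
+        # them into `-mean`; the tl.where re-zeroes those lanes so they do not bias the
+        # variance sum accumulated below.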
+ _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@torch.inference_mode() +def layer_norm(x, normalized_shape, weight, bias, eps=1e-5): + # allocate output + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + kernel = _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + # print(kernel.asm['ttir']) + return y + + +def _layer_norm(M, N, dtype, eps=1e-5, device='npu'): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + dy = .1 * torch.randn_like(x) + x.requires_grad_(True) + # forward pass + y_tri = layer_norm(x, w_shape, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) + # compare + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) + print(f"y_tri: {y_tri}") + print(f"y_ref: {y_ref}") + print(f"Layer Normalization {M},{N} {dtype} PASSED!") + + +if __name__ == "__main__": + _layer_norm(128, 128, torch.float16) + _layer_norm(128, 128, torch.bfloat16) + _layer_norm(128, 128, torch.float32) diff --git a/third_party/ascend/examples/tutorials/04-fused-attention.py b/third_party/ascend/examples/tutorials/04-fused-attention.py new file mode 100644 index 000000000..e7d673284 --- /dev/null +++ b/third_party/ascend/examples/tutorials/04-fused-attention.py @@ -0,0 +1,354 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Credits: OpenAI kernel team + +Extra Credits: + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +""" + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +DEVICE = "npu" + + +@triton.jit +def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, local m, query vector + K_block_ptr, V_block_ptr, # Key and value block pointers for current stage + start_m, qk_scale, # Starting position of current query block, qk scale factor + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # Block size constants + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # Current stage flag, m and n offset indices + N_CTX: tl.constexpr, fp8_v: tl.constexpr): # Total context length, 
whether to enable FP8 for value precision + # Set the processing range [lo, hi) for the current stage (in column block units) + # causal = true + # stage = 1 + # Causal attention, as the name implies, restricts the flow of information during computation, + # only allowing the model to see the current and previous positions. + # In other words, the output at the current position can only depend on the input at or before this position, + # and cannot access information from future positions. + # Causal attention ensures sequential order and prevents "leakage of future information." + # But the following logic will also be triggered + if STAGE == 1: + # Stage 1: process all tokens before the query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + # Stage 2: process the current query block + tl.static_assert(BLOCK_M >= BLOCK_N) + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # Align starting position + # causal = False (no need for masking) + else: + lo, hi = 0, N_CTX # Process the entire context + + # Adjust K and V block pointers to the starting position `lo` + K_block_ptr = tl.advance(K_block_ptr, (lo, 0)) # K is [HEAD_DIM, N_CTX], shift along the second dim by lo + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # V is [N_CTX, HEAD_DIM], shift along the first dim by lo + + # Index mapping for the accumulator , used for slicing when HEAD_DIM >= 256 + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + + # Iterate over all k, v blocks in the current stage and accumulate the output + for start_n in range(lo, hi, BLOCK_N): # Process BLOCK_N columns at a time + start_n = tl.multiple_of(start_n, BLOCK_N) # Align column start position + # -- Compute qk ---- + k = tl.load(K_block_ptr) + # Modify K + trans_k = tl.trans(k) + qk = tl.dot(q, trans_k) + # Apply causal mask for STAGE 2 + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) # Construct upper triangular mask + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) # Set invalid positions to -∞ + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Update m_ij = max(m_i, max(qk)) + qk -= m_ij[:, None] # Subtract max for softmax stability + else: + qk = qk * qk_scale + m_ij = tl.maximum(m_i, tl.max(qk, 1)) # Scaled max + qk = qk - m_ij[:, None] # Stabilize + + # Softmax weights p = exp(qk) + p = tl.math.exp(qk) + + # Convert softmax weight type depending on FP8 usage + if fp8_v: + p_cast = p.to(tl.float8e5) # Convert to FP8 format (save memory) + else: + p_cast = p.to(k.dtype) + + v = tl.load(V_block_ptr) # Load corresponding V block + pv = tl.dot(p_cast, v) + l_ij = tl.sum(p, 1) # Softmax denominator (sum of each row) + # -- Update m_i and l_i + alpha = tl.math.exp(m_i - m_ij) # Update factor: exp difference between old and new max + l_i = l_i * alpha + l_ij # Update softmax denominator + # -- Update output accumulator -- + if HEAD_DIM < 256: + acc_ptr = acc_ptr * alpha[:, None] + acc_ptr = tl.dot(p_cast, v, acc_ptr) + else: + # 1. Load current slice of accumulator + acc = tl.load(acc_ptr + block2d_acc) + # 2. 
Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow) + for i in range(4): + # Calculate start/end rows for current slice + offset = i * (BLOCK_M // 4) + # Extract slice data + acc_i = tl.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = tl.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = tl.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # Incrementally update slice: acc = acc * alpha + pv + acc_i = acc_i * alpha_i[:, None] + pv_i + # Write updated slice back to accumulator + acc = tl.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + # 3. updated accumulator + tl.store(acc_ptr + block2d_acc, acc) + + m_i = m_ij # Update current block max + # Advance V and K block pointers to next BLOCK_N range + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) + # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i + return acc_ptr, l_i, m_i + + +@triton.jit +def _attn_fwd(Q, K, V, M, Out, acc, sm_scale, + stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr, + stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr, + stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr, + stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr, + Z: tl.constexpr, H: tl.constexpr, + N_CTX: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr + ): + # Total number of blocks in sequence dimension (M) + NUM_BLOCKS_M = N_CTX // BLOCK_M + # Total tasks = number of sequence blocks × batch size (Z) × number of attention heads (H) + NUM_BLOCKS = NUM_BLOCKS_M * Z * H + + # Current M-dimension block index + pid = tl.program_id(0) + + for block_idx in range(pid, NUM_BLOCKS, 20): + task_hz_idx = block_idx // NUM_BLOCKS_M + task_m_idx = block_idx % NUM_BLOCKS_M + off_z = task_hz_idx // H + off_h = task_hz_idx % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + # Create block pointers for Q, K, V, Output + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, HEAD_DIM), + strides=(stride_om, stride_on), + offsets=(task_m_idx * BLOCK_M, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + # Initialize offsets + offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + + # Initialize accumulator + if HEAD_DIM < 256: + acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + else: + acc_offset = ( + off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM + + off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM + + task_m_idx * BLOCK_M * 
HEAD_DIM + ) + acc_ptr = acc + acc_offset + + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + # stage 2: on-band + if STAGE & 2: + # barrier makes it easier for compielr to schedule the + # two loops independently + acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr, # + task_m_idx, sm_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + ) + + m_i += tl.math.log(l_i) + if HEAD_DIM < 256: + accumulator = acc_ptr / l_i[:, None] + else: + row = tl.arange(0, BLOCK_M)[:, None] + col_head_dim = tl.arange(0, HEAD_DIM)[None, :] + block2d_acc = row * HEAD_DIM + col_head_dim + accumulator = tl.load(acc_ptr + block2d_acc) + accumulator = accumulator / l_i[:, None] + + m_ptrs = M + task_hz_idx * N_CTX + offs_m + + tl.store(m_ptrs, m_i) + tl.store(O_block_ptr, accumulator.to(Out.type.element_ty)) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, BM, BN): + """ + Forward computation interface: + Args: + ctx: Context object + q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM] + k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM] + v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM] + causal: Whether to enable causal attention + sm_scale: Scaling factor for QK product + BM: Q block size (BLOCK_M) + BN: K/V block size (BLOCK_N) + Returns: + o: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM] + """ + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. 
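+        # In this NPU port, v keeps the [Z, H, N_CTX, HEAD_DIM] layout, so the head
+        # dimension is read from the last axis.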
+ HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + + + # Number of NPU cores (adjust based on hardware) + num_cores = 20 + acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K), dtype=torch.float32, device=q.device) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + _attn_fwd[(num_cores,)]( + q, k, v, M, o, acc, sm_scale, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], N_CTX=q.shape[2], + HEAD_DIM=HEAD_DIM_K, + BLOCK_M=BM, + BLOCK_N=BN, + STAGE=stage, + **extra_kern_args) + + ctx.save_for_backward(q, k, v, o, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + return o + +attention = _attention.apply + + +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", [ + (1, 1, 128, 128, False, torch.float16, 32, 128), + (1, 1, 128, 128, False, torch.bfloat16, 64, 128), + (1, 2, 128, 128, False, torch.float16, 32, 128), + (1, 2, 256, 256, False, torch.bfloat16, 32, 256), + (2, 2, 128, 256, False, torch.float16, 64, 128), + (4, 32, 32, 64, False, torch.bfloat16, 32, 32), + (4, 32, 64, 64, False, torch.float16, 32, 64), + (4, 32, 1024, 64, False, torch.bfloat16, 64, 64), + (4, 32, 4096, 64, False, torch.float16, 64, 64), +]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN): + # 过滤非整切案例, N_CTX 需整除 BM 和 BN, 且 HEAD_DIM 需整除 16 + if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0: + pytest.skip("Skipping non-divisible case") + + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + + sm_scale = 0.5 + + tri_out = attention(q, k, v, causal, sm_scale, BM, BN) + ref_out = torch_npu.npu_fusion_attention( + q, k, v, H, + padding_mask=None, + atten_mask=None, + scale=sm_scale, + keep_prob=1.0, + input_layout="BNSD", + pre_tockens=65535, + next_tockens=65535, + sparse_mode=0, + )[0] + + try: + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True) + print(f"Test Fused-Attention PASS!") + except AssertionError as e: + print(f"Test Fused-Attention FAILED: ({Z},{H},{N_CTX},{HEAD_DIM}), causal={causal}, dtype={dtype}, BM={BM}, BN={BN}") + print(f"Error: {e}") + + +if __name__ == "__main__": + test_op(1, 1, 128, 128, causal=False, dtype=torch.float16, BM=32, BN=128) + test_op(1, 1, 128, 128, causal=False, dtype=torch.bfloat16, BM=64, BN=128) + test_op(1, 2, 128, 128, causal=False, dtype=torch.float16, BM=32, BN=128) + test_op(1, 2, 256, 256, causal=False, dtype=torch.bfloat16, BM=32, BN=256) + test_op(2, 2, 128, 256, causal=False, dtype=torch.float16, BM=64, BN=128) + test_op(4, 32, 32, 64, causal=False, dtype=torch.bfloat16, BM=32, BN=32) + test_op(4, 32, 64, 64, causal=False, dtype=torch.float16, BM=32, BN=64) + test_op(4, 32, 1024, 64, causal=False, dtype=torch.bfloat16, BM=64, BN=64) + test_op(4, 32, 4096, 64, causal=False, dtype=torch.float16, BM=64, BN=64) diff --git 
a/third_party/ascend/examples/tutorials/05-matrix-multiplication-flagtree.py b/third_party/ascend/examples/tutorials/05-matrix-multiplication-flagtree.py new file mode 100644 index 000000000..edb9f75da --- /dev/null +++ b/third_party/ascend/examples/tutorials/05-matrix-multiplication-flagtree.py @@ -0,0 +1,197 @@ +""" +Matrix Multiplication (Flagtree Hints Version) +=============== +""" + +import triton +import triton.language as tl +import torch +import torch_npu + +DEV = "npu" +activation = "leaky_relu_custom" + + +def get_autotune_config(): + return [ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), + triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + GROUP_SIZE_M: tl.constexpr = 1 + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + msk_m = offs_am < M + msk_n = offs_bn < N + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. 
+ a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk + a = tl.load( + a_ptrs, + mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=msk_n[None, :] and (offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + # Original vector operations + # if ACTIVATION == "leaky_relu_custom": + # accumulator = leaky_relu_custom(accumulator) + # c = accumulator.to(tl.float16) + # # ----------------------------------------------------------- + # # Write back the block of the output matrix C with masks. + # offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + # offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + # tl.store(c_ptrs, c, mask=c_mask) + # Comment out the following lines to enable split the workload to two vector cores using flagtree hints + SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 + for s in range(0, 2): # @hint: bind_sub_block + vec_sub_blk = tl.extract_slice( + accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1) + ) + if ACTIVATION == "leaky_relu_custom": + vec_sub_blk = leaky_relu_custom(vec_sub_blk) + c_sub_blk = vec_sub_blk.to(tl.float16) + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c_sub_blk, mask=c_mask) + + +# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu_custom(x): + return tl.where(x >= 0, x, 0.01 * x) + 1.0 + + +def torch_matmul(a, b, activation=""): + c = torch.matmul(a, b) + if activation == "leaky_relu_custom": + c = torch.where(c >= 0, c, 0.01 * c) + 1.0 + return c + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). 
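
Editor's note: the test below only prints the two outputs side by side. A sketch of an explicit comparison that could be appended after those prints; the tolerances are assumptions chosen for fp16 inputs with fp32 accumulation, not values taken from this patch.

import torch


def report_match(triton_output, torch_output, atol=1e-2, rtol=1e-2):
    # Hypothetical helper, not part of this patch.
    try:
        torch.testing.assert_close(triton_output, torch_output, atol=atol, rtol=rtol)
        print("Triton and Torch outputs match")
    except AssertionError as err:
        print("Triton and Torch outputs differ:")
        print(err)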
+torch.npu.set_device(1) +torch.manual_seed(0) +a = torch.randn((512, 512), device=DEV, dtype=torch.float16) +b = torch.randn((512, 512), device=DEV, dtype=torch.float16) +triton_output = matmul(a, b, activation) +torch_output = torch_matmul(a, b, activation) +print(f"triton_output_with_fp16_inputs={triton_output}") +print(f"torch_output_with_fp16_inputs={torch_output}") diff --git a/third_party/ascend/examples/tutorials/05-matrix-multiplication.py b/third_party/ascend/examples/tutorials/05-matrix-multiplication.py new file mode 100644 index 000000000..befa7551e --- /dev/null +++ b/third_party/ascend/examples/tutorials/05-matrix-multiplication.py @@ -0,0 +1,197 @@ +""" +Matrix Multiplication +=============== +""" + +import triton +import triton.language as tl +import torch +import torch_npu + +DEV = "npu" +activation = "leaky_relu_custom" + + +def get_autotune_config(): + return [ + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), + triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + GROUP_SIZE_M: tl.constexpr = 1 + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + msk_m = offs_am < M + msk_n = offs_bn < N + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. 
+ # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak + b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk + a = tl.load( + a_ptrs, + mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=msk_n[None, :] and (offs_k[:, None] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + # Original vector operations + # if ACTIVATION == "leaky_relu_custom": + # accumulator = leaky_relu_custom(accumulator) + # c = accumulator.to(tl.float16) + # # ----------------------------------------------------------- + # # Write back the block of the output matrix C with masks. + # offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + # offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + # tl.store(c_ptrs, c, mask=c_mask) + # Comment out the following lines to enable split the workload to two vector cores + SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 + for s in tl.parallel(0, 2, bind_sub_block=True): + vec_sub_blk = tl.extract_slice( + accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1) + ) + if ACTIVATION == "leaky_relu_custom": + vec_sub_blk = leaky_relu_custom(vec_sub_blk) + c_sub_blk = vec_sub_blk.to(tl.float16) + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c_sub_blk, mask=c_mask) + + +# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu_custom(x): + return tl.where(x >= 0, x, 0.01 * x) + 1.0 + + +def torch_matmul(a, b, activation=""): + c = torch.matmul(a, b) + if activation == "leaky_relu_custom": + c = torch.where(c >= 0, c, 0.01 * c) + 1.0 + return c + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
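+    # The 1D grid below launches cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs;
+    # matmul_kernel decomposes each program id back into (pid_m, pid_n).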
+ grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). +torch.npu.set_device(1) +torch.manual_seed(0) +a = torch.randn((512, 512), device=DEV, dtype=torch.float16) +b = torch.randn((512, 512), device=DEV, dtype=torch.float16) +triton_output = matmul(a, b, activation) +torch_output = torch_matmul(a, b, activation) +print(f"triton_output_with_fp16_inputs={triton_output}") +print(f"torch_output_with_fp16_inputs={torch_output}") diff --git a/third_party/ascend/examples/tutorials/06-demo-autotune.py b/third_party/ascend/examples/tutorials/06-demo-autotune.py new file mode 100644 index 000000000..7b017774d --- /dev/null +++ b/third_party/ascend/examples/tutorials/06-demo-autotune.py @@ -0,0 +1,58 @@ +""" +Autotune +============= +""" +import torch, torch_npu +import triton +import triton.language as tl + +def test_triton_autotune(): + # Return a set of different kernel configurations for autotune + def get_autotune_config(): + return [ + triton.Config({'XS': 1 * 128, 'multibuffer': True}), + triton.Config({'XS': 12 * 1024, 'multibuffer': True}), + triton.Config({'XS': 12 * 1024, 'multibuffer': False}), + triton.Config({'XS': 8 * 1024, 'multibuffer': True}), + ] + # Use @autotune decorator to automatically select the best kernel configuration + @triton.autotune( + configs=get_autotune_config(), # List of configurations + key=["numel"], # the change of numel will trigger autotuning + ) + @triton.jit + def triton_calc_kernel( + out_ptr0, in_ptr0, in_ptr1, numel, + XS: tl.constexpr # Block size controlling how many elements each thread block processes + ): + pid = tl.program_id(0) # Get current program ID + idx = pid * XS + tl.arange(0, XS) # Index range handled by current thread block + msk = idx < numel # Mask to avoid out-of-bound access + for i in range(10000): + tmp0 = tl.load(in_ptr0 + idx, mask=msk, other=0.0) # Load x0 + tmp1 = tl.load(in_ptr1 + idx, mask=msk, other=0.0) # Load x1 + tmp2 = tl.math.exp(tmp0) + tmp1 + i + tl.store(out_ptr0 + idx, tmp2, mask=msk) # Store result + # Function to call the Triton kernel with autotuned configuration + def triton_calc_func(x0, x1): + n = x0.numel() + y0 = torch.empty_like(x0) + grid = lambda meta: (triton.cdiv(n, meta["XS"]), 1, 1) + triton_calc_kernel[grid](y0, x0, x1, n) + return y0 + # Reference implementation using PyTorch for correctness check + def torch_calc_func(x0, x1): + return torch.exp(x0) + x1 + 10000-1 + + DEV = "npu" + DTYPE = torch.float32 + N = 192 * 1024 + x0 = torch.randn((N,), dtype=DTYPE, device=DEV) + x1 = torch.randn((N,), dtype=DTYPE, device=DEV) + torch_ref = torch_calc_func(x0, x1) + triton_cal = triton_calc_func(x0, x1) + torch.testing.assert_close(triton_cal, torch_ref) + +if __name__ == "__main__": + test_triton_autotune() + print("success: test_triton_autotune") diff --git a/third_party/ascend/examples/tutorials/07-profiler.py b/third_party/ascend/examples/tutorials/07-profiler.py new file mode 100644 index 000000000..3b7993c7b --- /dev/null +++ b/third_party/ascend/examples/tutorials/07-profiler.py @@ -0,0 +1,172 @@ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + + +def 
profiler_wrapper(fn, *args): + result_path = "./result_profiling" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), + record_shapes=True, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for i in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + +def test_add(x0, x1): + def torch_func(x0, x1): + res = x0 + x1 + return res + + @triton.jit + def triton_kernel_add(out_ptr0, in_ptr0, in_ptr1, + XS: tl.constexpr): + idx = tl.arange(0, XS) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.store(out_ptr0 + idx, tmp2) + + def triton_func(x0, x1): + y0 = torch.empty_like(x0) + triton_kernel_add[1, 1, 1](y0, x0, x1, N) + return y0 + + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1) + torch.testing.assert_close(triton_cal, torch_ref) + + def wrapper_func(x0, x1): + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1) + + profiler_wrapper(wrapper_func, x0, x1) + + +def test_or(x0, x1): + def torch_func(x0, x1): + res = x0 | x1 + return res + + @triton.jit + def triton_kernel_or(out_ptr0, in_ptr0, in_ptr1, + XS: tl.constexpr): + idx = tl.arange(0, XS) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 | tmp1 + tl.store(out_ptr0 + idx, tmp2) + + def triton_func(x0, x1): + y0 = torch.empty_like(x0) + triton_kernel_or[1, 1, 1](y0, x0, x1, N) + return y0 + + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1) + torch.testing.assert_close(triton_cal, torch_ref) + + def wrapper_func(x0, x1): + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1) + + profiler_wrapper(wrapper_func, x0, x1) + +def test_inductor_add(x0, x1): + # torch_npu._inductor requires torch_npu 2.6.0+ experimental version + import torch_npu._inductor + + def torch_func(x0, x1): + res = x0 + x1 + return res + + compiled_func = torch.compile(torch_func, backend="inductor") + profiler_wrapper(compiled_func, x0, x1) + print("[INFO] Check ./result_profiling directory to find the kernel_details.csv file. 
" + " Check the columns: Input Shapes,Input Data Types,Input Formats") + +if __name__ == "__main__": + test_case_is_inductor = False + N = 1024 + low = 1 + high = 100 + + # float32 + x0_fp32 = torch.rand((N,), dtype=torch.float32).npu() + x1_fp32 = torch.rand((N,), dtype=torch.float32).npu() + + # float16 + x0_fp16 = torch.rand((N,), dtype=torch.float16).npu() + x1_fp16 = torch.rand((N,), dtype=torch.float16).npu() + + # bfloat16 + x0_bf16 = torch.rand((N,), dtype=torch.bfloat16).npu() + x1_bf16 = torch.rand((N,), dtype=torch.bfloat16).npu() + + # int64 + x0_i64 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int64).npu() + x1_i64 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int64).npu() + + # int32 + x0_i32 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int32).npu() + x1_i32 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int32).npu() + + # int16 + x0_i16 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int16).npu() + x1_i16 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int16).npu() + + # int8 + x0_i8 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int8).npu() + x1_i8 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int8).npu() + + # bool (i1) + x0_i1 = torch.randint(low=0, high=2, size=(N,)).bool().npu() + x1_i1 = torch.randint(low=0, high=2, size=(N,)).bool().npu() + + test_cases = [ + ('fp32', x0_fp32, x1_fp32), + ('fp16', x0_fp16, x1_fp16), + ('bf16', x0_bf16, x1_bf16), + ('i64', x0_i64, x1_i64), + ('i32', x0_i32, x1_i32), + ('i16', x0_i16, x1_i16), + ('i8', x0_i8, x1_i8), + ('i1', x0_i1, x1_i1), + ] + + for dtype_name, x0, x1 in test_cases: + print(f"Running test for {dtype_name}...") + if dtype_name != 'i1': + if (test_case_is_inductor): + test_inductor_add(x0, x1) + else: + test_add(x0, x1) + else: + test_or(x0, x1) diff --git a/third_party/ascend/examples/tutorials/08-demo-libentry.py b/third_party/ascend/examples/tutorials/08-demo-libentry.py new file mode 100644 index 000000000..47f183722 --- /dev/null +++ b/third_party/ascend/examples/tutorials/08-demo-libentry.py @@ -0,0 +1,165 @@ +import time + +import numpy as np +import torch +import torch_npu + +import triton +import triton.language as tl +from triton.runtime.libentry import libentry + +device = torch.npu.current_device() +stream = torch.npu.current_stream(device) +stream_id = stream.npu_stream + + +def benchmark(func): + warmup = 10 + repeat = 100 + + def wrapper(*args, **kwargs): + # + for _ in range(warmup): + result = func(*args, **kwargs) + stream.synchronize() + # + start_time = time.perf_counter_ns() + for _ in range(repeat): + result = func(*args, **kwargs) + stream.synchronize() + end_time = time.perf_counter_ns() + # + start_time = start_time * 1e-3 + end_time = end_time * 1e-3 + elapsed_time = (end_time - start_time) / repeat + return (result, elapsed_time) + + return wrapper + + +def plot_performance_comparison(sizes, times_torch, times_triton, fname): + import matplotlib.pyplot as plt + + plt.rcParams["font.family"] = "Maple Mono NF CN" + plt.style.use('ggplot') + # + fig, ax = plt.subplots(figsize=(12, 8)) + ax.plot(sizes, times_torch, 'o-', label='Torch') + ax.plot(sizes, times_triton, 's-', label='Triton') + ax.set_title('Torch vs Triton Time Cost', fontsize=16) + ax.set_xlabel('Batch Size', fontsize=14) + ax.set_ylabel('Kernel Time (us)', fontsize=14) + ax.set_xlim([0, 2e4]) + ax.set_ylim([0, 500]) + ax.grid(True, linestyle='--', alpha=0.7) + ax.legend(fontsize=12) + plt.tight_layout() + fig.savefig(fname, dpi=300, 
bbox_inches='tight') + print(f"{fname} is saved") + + +def save_print_data(sizes, times_torch, times_triton, fname): + perf_data = np.zeros((len(sizes), 3)) + perf_data[:, 0] = sizes + perf_data[:, 1] = times_torch + perf_data[:, 2] = times_triton + np.savetxt(fname, perf_data, delimiter=",", header="batch, torch(us), triton(us)") + print("batch, torch(us), triton(us)") + for idx, size in enumerate(sizes): + print(f"{int(size)}, {times_torch[idx]}, {times_triton[idx]}") + + +def load_data(fname): + perf_data = np.loadtxt(fname, delimiter=",", skiprows=1) + sizes = perf_data[:, 0].astype(np.float32) + times_torch = perf_data[:, 1] + times_triton = perf_data[:, 2] + return sizes, times_torch, times_triton + + +@libentry() +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + XBLOCK: tl.constexpr, + XBLOCK_SUB: tl.constexpr, + RBLOCK: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * XBLOCK + rblk_idx = tl.arange(0, XBLOCK_SUB) + col_idx = tl.arange(0, RBLOCK) + for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): + row_offsets = row_start + row_idx + rblk_idx[:, None] + col_offsets = col_idx[None, :] + xmask = row_offsets < n_rows + ymask = col_offsets < n_cols + mask = xmask & ymask + input_idx = row_offsets * input_row_stride + col_offsets + input_ptrs = input_ptr + input_idx + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1) + softmax_output = numerator / denominator + output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) + tl.store(output_ptrs, softmax_output, mask=mask) + + +@benchmark +def torch_func(x0: torch.Tensor): + m = torch.nn.Softmax(dim=1) + return m(x0) + + +@benchmark +def triton_func(y0: torch.Tensor, x0: torch.Tensor, stream_id0: int): + n_rows, n_cols = x0.shape + ncore = 40 + xs = (n_rows + ncore - 1) // ncore + xss = min(xs, 5) + softmax_kernel[(ncore, 1, 1)]( + y0, + x0, + x0.stride(0), + y0.stride(0), + n_rows, + n_cols, + XBLOCK=xs, + XBLOCK_SUB=xss, + RBLOCK=n_cols, + stream=stream_id0, + ) + return y0 + + +torch.manual_seed(0) +DEV = "npu" +DTYPE = torch.float32 +seq_len = 2 * 1024 + +batch_sizes = [] +torch_times = [] +triton_times = [] +for i in range(1, 16 + 1): + batch = i * 1000 + batch_sizes.append(batch) + x = torch.rand((batch, seq_len), dtype=DTYPE, device=DEV) + y = torch.empty_like(x) + torch_out, torch_time = torch_func(x) + triton_out, triton_time = triton_func(y, x, stream_id) + torch.testing.assert_close(triton_out, torch_out) + torch_times.append(torch_time) + triton_times.append(triton_time) + +data_fname = "compare_perf_softmax_triton_torch.csv" +save_print_data(batch_sizes, torch_times, triton_times, data_fname) +# In case of you already have csv file, you can directly run load_data(data_fname) +# to load the data. +figname = "compare_perf_softmax_triton_torch.png" +plot_performance_comparison(batch_sizes, torch_times, triton_times, figname) diff --git a/third_party/ascend/examples/tutorials/09-gather.py b/third_party/ascend/examples/tutorials/09-gather.py new file mode 100644 index 000000000..2a182cd83 --- /dev/null +++ b/third_party/ascend/examples/tutorials/09-gather.py @@ -0,0 +1,125 @@ +""" +Gather +=============== +This is an example only for npu. 
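
Editor's note: a plain-PyTorch sketch of the operation this tutorial implements, illustrative only (`torch_gather` defined below is the file's own reference): rows are gathered by index, and negative indices fall back to a default value.

import torch

emb = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # a 4-row embedding table
idx = torch.tensor([2, -1, 0])                             # -1 requests the default value
out = torch.where(idx[:, None] >= 0,
                  emb[idx.clamp(min=0)],                   # gathered rows
                  torch.zeros(1, 3))                       # default-filled rows
print(out)
# tensor([[6., 7., 8.],
#         [0., 0., 0.],
#         [0., 1., 2.]])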
+""" + +import torch +import torch_npu +import triton +import triton.runtime.driver as driver +import triton.language as tl + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +# a torch-version gather benchmark +def torch_gather(embeddings, idxes, default_value=0.0): + # make the result tensor + res = torch.empty((idxes.shape[0], embeddings.shape[-1]), dtype=embeddings.dtype, device=embeddings.device) + + # scatter embeddings + res[idxes >= 0] = embeddings[idxes[idxes >= 0]] + # set default values + res[idxes < 0] = default_value + + return res + + +# triton-version gather's kernel +@triton.jit +def gather_kernel(embeddings_ptr, idxes_ptr, res_ptr, rows, cols, DEFAULT_VALUE: tl.constexpr, BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE_SUB: tl.constexpr): + SMALL_ROW_BLOCK_SIZE = BIG_ROW_BLOCK_SIZE - 1 + + embedding_dtype = embeddings_ptr.type.element_ty + default_value = tl.cast(DEFAULT_VALUE, dtype=embedding_dtype) + default_embedding = tl.full((COL_BLOCK_SIZE_SUB, ), default_value, dtype=embedding_dtype) + + core_idx = tl.program_id(0) + # compute the the size and start index of block + row_block_size = BIG_ROW_BLOCK_SIZE if (core_idx < BIG_CORE_NUM) else SMALL_ROW_BLOCK_SIZE + row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else (BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) + + # process blocks witn shape (row_block_size, COL_BLOCK_SIZE_SUB) one by one + for col_idx in tl.range(0, COL_BLOCK_SIZE, COL_BLOCK_SIZE_SUB): + emb_col_offsets = col_idx + tl.arange(0, COL_BLOCK_SIZE_SUB) + emb_col_mask = emb_col_offsets < cols + + for row_idx in tl.range(row_start_idx, min(row_start_idx + row_block_size, rows)): + idx_val = tl.load(idxes_ptr + row_idx) + + write_row_offset = row_idx * cols + write_emb_mask = emb_col_mask + + if idx_val >= 0: + read_row_offset = idx_val * cols + read_emb_mask = emb_col_mask + # read embedding + embedding = tl.load(embeddings_ptr + read_row_offset + emb_col_offsets, mask=read_emb_mask) + tl.store(res_ptr + write_row_offset + emb_col_offsets, embedding, write_emb_mask) + else: + # set default values + tl.store(res_ptr + write_row_offset + emb_col_offsets, default_embedding, write_emb_mask) + + +# triton-version gather's host +def triton_gather(embeddings: torch.Tensor, indices: torch.Tensor, default_value=0.0): + # constant settings for npu + USE_SIZE = 96 * 1024 + CORE_NUM = get_npu_properties()["num_vectorcore"] + + n_rows = indices.shape[0] + n_cols = embeddings.shape[1] + # make the result tensor + output = torch.empty(n_rows, n_cols, dtype=embeddings.dtype, device=embeddings.device) + + # when writing an npu kernel using triton, + # you should note that the difference between BLOCK_SIZE and BLOCK_SIZE_SUB + # BLOCK_SIZE specifies the size of data that are processed in one program + col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), 32) * 32 // embeddings.element_size() + # the data are scattered to multiple programs, which can not be even + # some process more data, some process less + big_row_block_size = triton.cdiv(n_rows, CORE_NUM) + big_core_num = CORE_NUM - ((big_row_block_size * CORE_NUM) - n_rows) + col_block_size = col_size_aligned + + # BLOCK_SIZE_SUB specifies the size of data that are processed in one loop of a program + max_col_block_size_sub = USE_SIZE // 
embeddings.element_size() // 2 + col_block_size_sub = min(col_size_aligned, max_col_block_size_sub) + + grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size)) + # launch the kernel + gather_kernel[grid](embeddings, indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub) + + return output + + +if __name__ == "__main__": + for n_rows in (500, 1000): + for n_cols in (16, 17, 31, 32, 63, 64, 128, 256, 819, 512, 1024, 8192, 1001, 2003, 17000): + for index_num in (19, 123, 4321, 54321, 100, 200, 819, 500, 700, 1000): + print(n_rows, n_cols, index_num, flush=True) + + indices = torch.randint(0, n_rows, (index_num, ), dtype=torch.int32).npu() + embeddings = torch.randn(n_rows, n_cols, dtype=torch.float).npu() + + expect = torch_gather(embeddings, indices).cpu() + actual = triton_gather(embeddings, indices).cpu() + torch.npu.synchronize() + mask = ~(expect == actual) + + error_count = mask.sum().item() + total_count = mask.numel() + print("error rate:", error_count / total_count, flush=True) + + print("error detail:") + print("===========", flush=True) + print(expect[mask], flush=True) + print("===========", flush=True) + print(actual[mask], flush=True) + print("===========", flush=True) + print(flush=True) diff --git a/third_party/ascend/examples/tutorials/10-gather_sorted.py b/third_party/ascend/examples/tutorials/10-gather_sorted.py new file mode 100644 index 000000000..4c50e2259 --- /dev/null +++ b/third_party/ascend/examples/tutorials/10-gather_sorted.py @@ -0,0 +1,186 @@ +""" +Gather sorted +=============== +This is an example only for npu. +""" + +import torch +import torch_npu +import triton +import triton.runtime.driver as driver +import triton.language as tl + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +# a torch-version gather_sorted benchmark +def torch_gather_sorted(embeddings, sorted_idxes, aux_idxes): + # make the result tensor + res = torch.empty((aux_idxes.shape[0], embeddings.shape[-1]), dtype=embeddings.dtype, device=embeddings.device) + + # scatter embeddings + res[aux_idxes] = embeddings[sorted_idxes] + + return res + + +# triton-version gather_sorted's kernel +@triton.jit +def gather_sorted_kernel(embeddings_ptr, sorted_indices_ptr, aux_indices_ptr, res_ptr, rows, cols, DEFAULT_VALUE: tl.constexpr, BIG_CORE_NUM: tl.constexpr, BIG_ROW_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE_SUB: tl.constexpr): + SMALL_ROW_BLOCK_SIZE = BIG_ROW_BLOCK_SIZE - 1 + + emb_dtype = embeddings_ptr.type.element_ty + default_value = tl.cast(DEFAULT_VALUE, dtype=emb_dtype) + + core_idx = tl.program_id(0) + # compute the the size and start index of block + row_block_size = BIG_ROW_BLOCK_SIZE if (core_idx < BIG_CORE_NUM) else SMALL_ROW_BLOCK_SIZE + row_start_idx = (core_idx * BIG_ROW_BLOCK_SIZE) if (core_idx < BIG_CORE_NUM) else (BIG_CORE_NUM * BIG_ROW_BLOCK_SIZE + (core_idx - BIG_CORE_NUM) * SMALL_ROW_BLOCK_SIZE) + + # this version has 3-buffers, initilize for buffers + row_block_size_0 = tl.cdiv(row_block_size, 3) + remain_row_block_size = row_block_size - row_block_size_0 + row_block_size_1 = tl.cdiv(remain_row_block_size, 2) + row_block_size_2 = remain_row_block_size - row_block_size_1 + + row_start_idx_0 = row_start_idx + row_start_idx_1 = row_start_idx + row_block_size_0 + row_start_idx_2 = row_start_idx + 
row_block_size_0 + row_block_size_1 + + # process blocks witn shape (row_block_size, COL_BLOCK_SIZE_SUB) one by one + for col_idx in tl.range(0, COL_BLOCK_SIZE, COL_BLOCK_SIZE_SUB): + + embedding_0 = tl.full((COL_BLOCK_SIZE_SUB, ), default_value, dtype=emb_dtype) + embedding_1 = embedding_0 + 0 + embedding_2 = embedding_0 + 0 + + emb_offsets = col_idx + tl.arange(0, COL_BLOCK_SIZE_SUB) + emb_mask = emb_offsets < cols + + prev_embedding_idx_0 = tl.cast(-1, dtype=tl.int32) + prev_embedding_idx_1 = tl.cast(-1, dtype=tl.int32) + prev_embedding_idx_2 = tl.cast(-1, dtype=tl.int32) + for row_idx in tl.range(row_start_idx_0, row_start_idx_1): + # process the first buffer + embedding_idx_0 = tl.load(sorted_indices_ptr + row_idx) + res_idx_0 = tl.load(aux_indices_ptr + row_idx) + + if (embedding_idx_0 != 0) and (embedding_idx_0 != prev_embedding_idx_0): + embedding_0 = tl.load(embeddings_ptr + embedding_idx_0 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_0 * cols + emb_offsets, embedding_0, emb_mask) + else: + tl.store(res_ptr + res_idx_0 * cols + emb_offsets, embedding_0, emb_mask) + + prev_embedding_idx_0 = embedding_idx_0 + + # process the second buffer + if (row_idx + row_block_size_0) < (row_start_idx_1 + row_block_size_1): + embedding_idx_1 = tl.load(sorted_indices_ptr + row_idx + row_block_size_0) + res_idx_1 = tl.load(aux_indices_ptr + row_idx + row_block_size_0) + + if (embedding_idx_1 != 0) and (embedding_idx_1 != prev_embedding_idx_1): + embedding_1 = tl.load(embeddings_ptr + embedding_idx_1 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_1 * cols + emb_offsets, embedding_1, emb_mask) + else: + tl.store(res_ptr + res_idx_1 * cols + emb_offsets, embedding_1, emb_mask) + + prev_embedding_idx_1 = embedding_idx_1 + + # process the third buffer + if (row_idx + row_block_size_0 + row_block_size_1) < (row_start_idx_2 + row_block_size_2): + embedding_idx_2 = tl.load(sorted_indices_ptr + row_idx + row_block_size_0 + row_block_size_1) + res_idx_2 = tl.load(aux_indices_ptr + row_idx + row_block_size_0 + row_block_size_1) + + if (embedding_idx_2 != 0) and (embedding_idx_2 != prev_embedding_idx_2): + embedding_2 = tl.load(embeddings_ptr + embedding_idx_2 * cols + emb_offsets, emb_mask) + tl.store(res_ptr + res_idx_2 * cols + emb_offsets, embedding_2, emb_mask) + else: + tl.store(res_ptr + res_idx_2 * cols + emb_offsets, embedding_2, emb_mask) + + prev_embedding_idx_2 = embedding_idx_2 + + +# triton-version gather_sorted's host +def triton_gather_sorted(embeddings: torch.Tensor, sorted_indices: torch.Tensor, aux_indices: torch.Tensor, default_value=1.0): + # constant settings for npu + ALIGNED = 32 + USE_SIZE = 96 * 1024 + CORE_NUM = get_npu_properties()["num_vectorcore"] + + n_rows = sorted_indices.shape[0] + n_cols = embeddings.shape[1] + # make the result tensor + output = torch.empty(n_rows, n_cols, dtype=embeddings.dtype, device=embeddings.device) + + # when writing an npu kernel using triton, + # you should note that the difference between BLOCK_SIZE and BLOCK_SIZE_SUB + # BLOCK_SIZE specifies the size of data that are processed in one program + col_size_aligned = triton.cdiv(embeddings.shape[-1] * embeddings.element_size(), ALIGNED) * ALIGNED // embeddings.element_size() + # the data are scattered to multiple programs, which can not be even + # some process more data, some process less + big_row_block_size = triton.cdiv(n_rows, CORE_NUM) + big_core_num = CORE_NUM - ((big_row_block_size * CORE_NUM) - n_rows) + col_block_size = col_size_aligned + # BLOCK_SIZE_SUB 
specifies the size of data that are processed in one loop of a program + col_block_size_sub = min(1024, col_size_aligned) + + grid = (min(n_rows, CORE_NUM), triton.cdiv(n_cols, col_block_size)) + # launch the kernel + gather_sorted_kernel[grid](embeddings, sorted_indices, aux_indices, output, n_rows, n_cols, default_value, BIG_CORE_NUM=big_core_num, BIG_ROW_BLOCK_SIZE=big_row_block_size, COL_BLOCK_SIZE=col_block_size, COL_BLOCK_SIZE_SUB=col_block_size_sub) + + return output + + +# genreate the desired inputs +def generate_inputs(index_shape, table_shape, dtype): + sorted_indices = torch.randint(1, table_shape[0], index_shape, dtype=torch.int32).npu() + mask = torch.rand_like(sorted_indices, dtype=torch.float).npu() < 0.2 + + # make sorted_indices + sorted_indices[mask] = 0 + sorted_indices, _ = torch.sort(sorted_indices) + counts = torch.bincount(sorted_indices) + _, _indices = torch.sort(counts[sorted_indices], descending=True, stable=True) + sorted_indices = sorted_indices[_indices] + + # make aux_indicess + aux_indices = torch.arange(0, index_shape[0], dtype=torch.int32).npu() + _indices = torch.randperm(aux_indices.size(0)) + aux_indices = aux_indices[_indices] + + # make table, the first contains only 1.0 + table = torch.randn(table_shape, dtype=dtype).npu() + table[0] = 1.0 + + return table, sorted_indices, aux_indices + + +if __name__ == "__main__": + for table_rows in (500, 1000): + for table_cols in (16, 17, 31, 32, 63, 64, 128, 256, 819, 512, 1024, 8192, 1001, 2003, 17000): + for index_num in (19, 123, 4321, 54321, 100, 200, 819, 500, 700, 1000): + print(table_rows, table_cols, index_num, flush=True) + + table, sorted_indices, aux_indices = generate_inputs((index_num,), (table_rows, table_cols), torch.float) + + expect = torch_gather_sorted(table, sorted_indices, aux_indices).cpu() + torch.npu.synchronize() + actual = triton_gather_sorted(table, sorted_indices, aux_indices).cpu() + torch.npu.synchronize() + mask = ~(expect == actual) + + error_count = mask.sum().item() + total_count = mask.numel() + print("error rate:", error_count / total_count, flush=True) + + print("error detail:") + print("===========", flush=True) + print(expect[mask], flush=True) + print("===========", flush=True) + print(actual[mask], flush=True) + print("===========", flush=True) + print(flush=True) diff --git a/third_party/ascend/examples/tutorials/11-rab_time.py b/third_party/ascend/examples/tutorials/11-rab_time.py new file mode 100644 index 000000000..29ddbf4c3 --- /dev/null +++ b/third_party/ascend/examples/tutorials/11-rab_time.py @@ -0,0 +1,371 @@ +""" +Relative Attention Bias Timestamps +=============== +""" + +import math +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + +NUM_BUCKETS = 128 +BUCKET_DIVISOR = 0.301 + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +def create_pos_w(train_len: int, num_layers: int) -> torch.Tensor: + return torch.arange(0, 2 * train_len + 1).unsqueeze(1).repeat(1, num_layers) + + +def create_past_valid_lens(bs: int, past_len: int) -> torch.Tensor: + return torch.randint(0, past_len, (bs,)) + + +def create_timestamps( + train_len: int, candidate_len: int, past_valid_lens: torch.Tensor +) -> torch.Tensor: + bs = past_valid_lens.size(0) + timestamps = torch.zeros(bs, train_len + candidate_len // 2) + for i, valid_len in enumerate(past_valid_lens): + if valid_len > 0: + timestamps[i, 
:valid_len] = torch.arange(1, valid_len.int() + 1) + + if candidate_len <= 0: + return timestamps + timestamps[:, -candidate_len // 2:] = train_len + 1 + + return timestamps + + +def create_timestamps_weights(num_layers: int): + return ( + torch.arange(0, NUM_BUCKETS + 1) + .repeat(num_layers) + .reshape(NUM_BUCKETS + 1, num_layers) + ) + + +def create_rab_time_grad(num_layers: int, batchsize: int, s: int): + return torch.rand(num_layers, batchsize, s, s) * 1e-4 + + +def create_bucket_timestamps(batchsize: int, s: int): + result = torch.arange(batchsize * s) % NUM_BUCKETS + result = result.unsqueeze(-1).repeat(1, 1, s) + return result + + +@triton.jit +def rab_time_forward_kernel( + inp, + out, + index, + index_len: tl.constexpr, + inp_row_stride: tl.constexpr, + clamp_max: tl.constexpr, + bucketization_divisor: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + COL_BLOCK_SIZE: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + col_iter_num = tl.cdiv(BLOCK_SIZE, COL_BLOCK_SIZE) + + for col_idx in tl.range(0, col_iter_num): + cols_offsets = ( + pid0 * BLOCK_SIZE + col_idx * COL_BLOCK_SIZE + tl.arange(0, COL_BLOCK_SIZE) + ) + cols_mask = cols_offsets < index_len + + out_mask = cols_offsets < index_len + + index_val = tl.load(index + cols_offsets, mask=cols_mask, other=0.0) + index_val = tl.abs(index_val) + index_val = tl.minimum(tl.maximum(index_val, 1.0), clamp_max) + index_val = tl.log(index_val) + index_val = index_val / bucketization_divisor + index_val = tl.cast(index_val, tl.int64) + + inp_val = tl.load(inp + pid1 * inp_row_stride + tl.arange(0, inp_row_stride)) + out_val = tl.gather(inp_val, index_val, 0) + + tl.store(out + pid1 * index_len + cols_offsets, out_val, mask=out_mask) + + +def get_outer_loop_num(num_layers, index_len): + sub_num_layers = num_layers + while sub_num_layers * index_len >= 2**31 - 1: + sub_num_layers = sub_num_layers // 2 + outer_loop_num = (num_layers + sub_num_layers - 1) // sub_num_layers + remain_layers = num_layers % sub_num_layers + return outer_loop_num, sub_num_layers, remain_layers + + +def rab_time_forward_triton(ts_w, timestamps, bucketization_divisor): + ts_w_trans = ts_w.t().contiguous() + + bs, seq_len = timestamps.shape + infer_len = 2 * seq_len + num_layers = ts_w.shape[1] + num_buckets = ts_w.shape[0] - 1 + + timestamps_expanded = timestamps.unsqueeze(-1).repeat(1, 1, 2) + timestamps_expanded = timestamps_expanded.reshape( + bs, infer_len, 1 + ) - timestamps_expanded.reshape(bs, 1, infer_len) + + timestamps_expanded = timestamps_expanded.view(-1) + timestamps_expanded = timestamps_expanded.contiguous() + + clamp_max = torch.exp(torch.tensor(num_buckets * bucketization_divisor)).item() + index_len = bs * infer_len * infer_len + + out = torch.empty((num_layers, index_len), dtype=ts_w.dtype, device=ts_w.device) + outer_loop_num, sub_num_layers, remain_layers = get_outer_loop_num( + num_layers, index_len + ) + + CORE_NUM = get_npu_properties()["num_vectorcore"] + BLOCK_SIZE = math.ceil(index_len / CORE_NUM) + COL_BLOCK_SIZE = 8 * 1024 + + curr_layers = sub_num_layers + for i in range(outer_loop_num): + if i == outer_loop_num - 1 and remain_layers != 0: + curr_layers = remain_layers + grid = lambda meta: (triton.cdiv(index_len, meta["BLOCK_SIZE"]), curr_layers) + + rab_time_forward_kernel[grid]( + ts_w_trans[i * sub_num_layers], + out[i * sub_num_layers], + timestamps_expanded, + index_len, + num_buckets + 1, + clamp_max, + bucketization_divisor, + BLOCK_SIZE, + COL_BLOCK_SIZE, + ) + + out = out.view(num_layers, bs, 
infer_len, infer_len) + + return out + + +@triton.jit +def rab_time_backward_kernel( + inp, src, index, index_len, BLOCK_SIZE: tl.constexpr, COL_BLOCK_SIZE: tl.constexpr +): + pid0 = tl.program_id(axis=0) + total_col_num = ( + BLOCK_SIZE + if pid0 * BLOCK_SIZE + BLOCK_SIZE < index_len + else index_len - pid0 * BLOCK_SIZE + ) + COL_BLOCK_SIZE = min(COL_BLOCK_SIZE, total_col_num) + col_iter_num = (total_col_num + COL_BLOCK_SIZE - 1) // COL_BLOCK_SIZE + + for col_idx in tl.range(0, col_iter_num): + base_idx = 0 + base_idx = base_idx.to(index.dtype.element_ty) + + col_start_offset = col_idx * COL_BLOCK_SIZE + + acc_result = 0.0 + acc_result = acc_result.to(inp.dtype.element_ty) + cur_col_num = ( + COL_BLOCK_SIZE + if col_start_offset + COL_BLOCK_SIZE < total_col_num + else total_col_num - col_start_offset + ) + + for cur_idx in range(0, cur_col_num): + cur_offset = pid0 * BLOCK_SIZE + col_start_offset + cur_idx + + src_val = tl.load(src + cur_offset) + new_idx = tl.load(index + cur_offset) + + if base_idx == new_idx: + acc_result += src_val + else: + tl.atomic_add(inp + base_idx, acc_result) + + base_idx = new_idx + acc_result = 0.0 + acc_result = acc_result.to(inp.dtype.element_ty) + acc_result += src_val + + tl.atomic_add(inp + base_idx, acc_result) + + +def rab_time_backward_triton( + rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor +): + num_layers, b, s, _ = rab_time_grad.shape + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to( + rab_time_grad.device + ) + + bucket_timestamps_expand = ( + bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) + .repeat(1, 1, 2, 1, 2) + .reshape(b, s, s) + .to(torch.int64) + ).view(-1) + + index_len = bucket_timestamps_expand.numel() + + rab_time_grad_f32 = rab_time_grad.to(torch.float32) + sorted_bucket_timestamps_expand, sorted_idx = torch.sort( + bucket_timestamps_expand.view(-1) + ) + + torch.npu.synchronize() + + grid = lambda meta: (triton.cdiv(index_len, meta["BLOCK_SIZE"]),) + + CORE_NUM = get_npu_properties()["num_vectorcore"] + BLOCK_SIZE = math.ceil(index_len / CORE_NUM) + + COL_BLOCK_SIZE = 8 * 1024 + + for layer_idx in range(num_layers): + curr_sorted_grad_f32 = rab_time_grad_f32[layer_idx].view(-1)[sorted_idx] + rab_time_backward_kernel[grid]( + tsw_grad[layer_idx], + curr_sorted_grad_f32, + sorted_bucket_timestamps_expand, + index_len, + BLOCK_SIZE, + COL_BLOCK_SIZE, + ) + + return tsw_grad + + +def rab_time_forward_golden( + ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float +) -> torch.Tensor: + """ + torch realization of rab time forward for reference. + """ + infer_len = timestamps.shape[1] * 2 + bs = timestamps.shape[0] + num_layers = ts_w.shape[1] + + timestamps = timestamps.unsqueeze(-1).repeat(1, 1, 2) + diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape( + bs, 1, infer_len + ) + + clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) + diff_timestamps = ( + torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) + / bucketization_divisor + ) + bucket_timestamps = diff_timestamps.long() + bucket_timestamps = bucket_timestamps.view(-1) + result = torch.index_select(ts_w, dim=0, index=bucket_timestamps) + + result = result.t() + + result = result.view(num_layers, bs, infer_len, infer_len) + return result + + +def rab_time_backward_golden( + rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor +): + """ + torch realization of rab time backward for reference. 
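+    For every layer n and bucket id k, the weight gradient sums all elements of
+    rab_time_grad[n] whose expanded bucket index equals k; per layer this is a
+    scatter_add of grad.view(-1) into a zero tensor of length NUM_BUCKETS,
+    indexed by bucket_timestamps_expand.view(-1). The Triton path above sorts
+    the bucket ids first and accumulates the per-bucket sums with tl.atomic_add.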
+ """ + num_layers, b, s, _ = rab_time_grad.shape + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to( + rab_time_grad.device + ) + + bucket_timestamps_expand = ( + bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) + .repeat(1, 1, 2, 1, 2) + .reshape(b, s, s) + .to(torch.int64) + ) + for n, grad in enumerate(rab_time_grad.to(torch.float32)): + tsw_grad[n] = tsw_grad[n].scatter_add( + src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0 + ) + return tsw_grad + + +def rab_time_forward_test(num_layers, train_len, candidate_len, bs, dtype): + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to( + torch.int32 + ) + timestamps_weights = create_timestamps_weights(num_layers).to(dtype) + timestamps = timestamps.npu() + timestamps_weights = timestamps_weights.npu() + + torch_npu.npu.synchronize() + + # triton output + rab_time_out_triton = rab_time_forward_triton( + ts_w=timestamps_weights, + timestamps=timestamps, + bucketization_divisor=BUCKET_DIVISOR, + ) + torch_npu.npu.synchronize() + + # pytorch output + rab_time_out_golden = rab_time_forward_golden( + ts_w=timestamps_weights, + timestamps=timestamps, + bucketization_divisor=BUCKET_DIVISOR, + ) + torch_npu.npu.synchronize() + + torch.testing.assert_close(rab_time_out_triton, rab_time_out_golden) + print(f"test pass!") + + +def rab_time_backward_test(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): + grad = create_rab_time_grad(num_layers, batchsize, s).to(dtype).npu() + bucket_timestamps = ( + create_bucket_timestamps(batchsize, s // 2).to(torch.int32).npu() + ) + + torch_npu.npu.synchronize() + + golden_result = ( + rab_time_backward_golden(grad, bucket_timestamps).to(torch.float32).cpu() + ) + op_result = ( + rab_time_backward_triton(grad, bucket_timestamps).to(torch.float32).cpu() + ) + + loss = 1e-4 if dtype == torch.float32 else 1e-3 + torch.testing.assert_close(op_result, golden_result, rtol=loss, atol=loss) + print(f"test pass!") + + +if __name__ == "__main__": + num_layers = 8 + train_len = 500 + candidate_len = 500 + batch_size = 4 + data_type = torch.float32 + print("running rab time forward test:") + rab_time_forward_test(num_layers, train_len, candidate_len, batch_size, data_type) + + print("running rab time backward test:") + rab_time_backward_test( + num_layers, batch_size, 2 * train_len + candidate_len, data_type + ) diff --git a/third_party/ascend/examples/tutorials/12-hstu_attention.py b/third_party/ascend/examples/tutorials/12-hstu_attention.py new file mode 100644 index 000000000..6cdea48d7 --- /dev/null +++ b/third_party/ascend/examples/tutorials/12-hstu_attention.py @@ -0,0 +1,763 @@ +""" +HSTU Attention +=============== +""" + +from typing import List, Optional, Tuple +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver +import numpy as np +import torch.nn.functional as F + +DEVICE = "npu" +BLOCK_FWD = 64 +BLOCK_BWD = 32 + + +def get_npu_properties(coreType): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device)[coreType] + + +@triton.jit +def _hstu_attn_fwd_one_block( + q, + k_block_ptr, + v_block_ptr, + bias_block_ptr, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + mask_block, +): + k = tl.load(k_block_ptr) + qk = tl.dot(q, tl.trans(k)) * alpha + if HAS_BIAS: + rel_attn_bias = tl.load(bias_block_ptr) + qk = qk + rel_attn_bias + silu = qk / 
(1.0 + tl.exp(-qk)) * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + silu = tl.where(mask_block, silu, 0) + v = tl.load(v_block_ptr) + silu = silu.to(v.dtype) + return tl.dot(silu, v) + + +@triton.jit +def _hstu_attn_fwd_compute( # noqa C901 + Q, + K, + V, + seq_offsets, + Out, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_om: tl.constexpr, + stride_oh: tl.constexpr, + alpha, + head_num, + MAX_SEQ_LEN, + off_batch, + off_head, + start_m, + seq_start, + seq_len, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + mask_block, + bias, +): + off_head = off_head.to(tl.int64) + off_seq = seq_start.to(tl.int64) + start_m = start_m.to(tl.int32) + + # initialize offsets + q_offset = off_seq * stride_qm + off_head * stride_qh + k_offset = off_seq * stride_kn + off_head * stride_kh + v_offset = off_seq * stride_vn + off_head * stride_kh + + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) + k_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_kn, 1), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_D_Q), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seq_len, BLOCK_D_V), + strides=(stride_vn, 1), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_D_V), + order=(1, 0), + ) + q = tl.load(Q_block_ptr) + + acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32) + if CAUSAL: + low = 0 + high = start_m + BLOCK_M + else: + low = 0 + high = seq_len + + bias_block_ptr = None + if HAS_BIAS: + bias_block_ptr = tl.make_block_ptr( + base=bias + off_batch * head_num * MAX_SEQ_LEN * MAX_SEQ_LEN + off_head * MAX_SEQ_LEN * MAX_SEQ_LEN, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + for start_n in range(low, high, BLOCK_N): + acc += _hstu_attn_fwd_one_block( + q=q, + k_block_ptr=k_block_ptr, + v_block_ptr=v_block_ptr, + bias_block_ptr=bias_block_ptr, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL and start_m == start_n, + HAS_BIAS=HAS_BIAS, + mask_block=mask_block, + ) + k_block_ptr = tl.advance(k_block_ptr, (BLOCK_N, 0)) + v_block_ptr = tl.advance(v_block_ptr, (BLOCK_N, 0)) + if HAS_BIAS: + bias_block_ptr = tl.advance(bias_block_ptr, (0, BLOCK_N)) + + # rematerialize offsets to save registers + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + off_seq * stride_om + off_head * stride_oh + offs_m = start_m + tl.arange(0, BLOCK_M) + out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + + +@triton.jit +def _hstu_attn_fwd( # noqa C901 + Q, + K, + V, + seq_offsets, + Out, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_om: tl.constexpr, + stride_oh: tl.constexpr, + alpha: tl.constexpr, + batch: tl.constexpr, + head_num: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + head_dim: tl.constexpr, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + CORE_NUM: tl.constexpr, + tasks: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + mask, + bias, +): + core_id = tl.program_id(0) + 
cur_batch = 0 + mask_block = None + if CAUSAL and mask is not None: + mask_ptr = tl.make_block_ptr( + base=mask, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_M), + order=(1, 0), + ) + mask_block = tl.load(mask_ptr) + for col in range(core_id, tasks, CORE_NUM): + seq_end = tl.load(seq_offsets + cur_batch + 1) + start_m = col * BLOCK_M + while start_m >= seq_end * head_num // 2: + cur_batch += 1 + seq_end = tl.load(seq_offsets + cur_batch + 1) + seq_start = tl.load(seq_offsets + cur_batch) + seq_len = seq_end - seq_start + off_batch = cur_batch + off_head = (start_m - seq_start * head_num // 2) // (seq_len // 2) + start_m_1 = (start_m - seq_start * head_num // 2) % (seq_len // 2) + start_m_2 = seq_len - start_m_1 - BLOCK_M + _hstu_attn_fwd_compute(Q, K, V, seq_offsets, Out, stride_qm, stride_qh, stride_kn, stride_kh, + stride_vn, stride_vh, stride_om, stride_oh, alpha, head_num, MAX_SEQ_LEN, off_batch, off_head, + start_m_1, seq_start, seq_len, CAUSAL, HAS_BIAS, head_dim, head_dim, BLOCK_M, BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + _hstu_attn_fwd_compute(Q, K, V, seq_offsets, Out, stride_qm, stride_qh, stride_kn, stride_kh, + stride_vn, stride_vh, stride_om, stride_oh, alpha, head_num, MAX_SEQ_LEN, off_batch, off_head, + start_m_2, seq_start, seq_len, CAUSAL, HAS_BIAS, head_dim, head_dim, BLOCK_M, BLOCK_N, + mask_block=mask_block, + bias=bias, + ) + + +@triton.jit +def _hstu_attn_bwd_one_block( # noqa C901 + start_m, + offs_n, + offs_m, + q_ptrs, + dq_ptrs, + mask_n, + do_ptrs, + dk, + dv, + k, + v, + pos_offs_n, + seq_len, + max_ids, + stride_qm, + stride_dom, + stride_dqm, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias_block_ptr, +): + pos_offs_m = offs_m + start_m + mask_m = pos_offs_m < seq_len + # recompute qk and silu + q = tl.load( + q_ptrs + start_m * stride_qm, + mask=mask_m[:, None], + other=0.0, + ) + q_trans = tl.trans(q) + qk_trans = tl.dot(k, q_trans) * alpha + if HAS_BIAS: + rel_attn_bias = tl.load(bias_block_ptr) + qk_trans = qk_trans + tl.trans(rel_attn_bias) + sig_trans = 1.0 / (1.0 + tl.exp(-qk_trans)) + silu_trans = qk_trans * sig_trans * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + invalid_mask_trans = pos_offs_m[None, :] == offs_n[:, None] + pos_offs_m_minus_n = pos_offs_m[None, :] - pos_offs_n[:, None] + invalid_mask_trans = invalid_mask_trans | (pos_offs_m_minus_n > 0) + silu_trans = tl.where(invalid_mask_trans, silu_trans, 0) + silu_trans = silu_trans.to(k.dtype) + # compute dv + do = tl.load( + do_ptrs + start_m * stride_dom, + mask=mask_m[:, None], + other=0.0, + ) + dv += tl.dot(silu_trans, do) + # compute dk and dq (dqk = do * v^T dk = dqk^T * q dq = dqk * k) + dqk_trans = tl.dot(v, tl.trans(do)) + dqk_trans = dqk_trans * sig_trans * (1 + qk_trans * (1 - sig_trans)) * (1.0 / MAX_SEQ_LEN) + if CAUSAL: + dqk_trans = tl.where(invalid_mask_trans, dqk_trans, 0) + dqk_trans = dqk_trans.to(k.dtype) + dq = tl.load( + dq_ptrs + start_m * stride_dqm, + mask=mask_m[:, None], + other=0.0, + ) + dq += tl.dot(tl.trans(dqk_trans), k) * alpha + tl.store( + dq_ptrs + start_m * stride_dqm, + dq, + mask=mask_m[:, None], + ) + # Note: the factor `alpha` is delayed until the end of the function to reduce the cost + dk += tl.dot(dqk_trans, q) + return dk, dv + + +@triton.jit +def _hstu_attn_bwd_one_col_block( # noqa C901 + start_n, + seq_len, + Q, + K, + V, + DOut, + DQ, + DK, + DV, + stride_qm, + stride_kn, + stride_vn, + stride_dom, 
+ stride_dqm, + stride_dkn, + stride_dvn, + alpha, + MAX_SEQ_LEN, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias, +): + # Work on the subsequence dv[start_n, start_n + BLOCK_N, :] + if CAUSAL: + low = start_n + high = seq_len + else: + low = 0 + high = seq_len + + # initialize row/col offsets + offs_m = tl.arange(0, BLOCK_M) + offs_qk_d = tl.arange(0, BLOCK_D_Q) + offs_v_d = tl.arange(0, BLOCK_D_V) + offs_n = start_n + tl.arange(0, BLOCK_N) + + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_qk_d[None, :]) + dk = tl.zeros([BLOCK_N, BLOCK_D_Q], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_V], dtype=tl.float32) + + mask_n = offs_n < seq_len + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_qk_d[None, :]) + do_ptrs = DOut + (offs_m[:, None] * stride_dom + offs_v_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_qk_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_v_d[None, :]) + k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + + max_ids = seq_len + pos_offs_n = offs_n + # loop over rows + for start_m in tl.range(low, high, BLOCK_M): + bias_block_ptr = None + if HAS_BIAS: + bias_block_ptr = tl.make_block_ptr( + base=bias, + shape=(MAX_SEQ_LEN, MAX_SEQ_LEN), + strides=(MAX_SEQ_LEN, 1), + offsets=(start_m, start_n), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs=q_ptrs, + dq_ptrs=dq_ptrs, + mask_n=mask_n, + do_ptrs=do_ptrs, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL, + HAS_BIAS=HAS_BIAS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + bias_block_ptr=bias_block_ptr, + ) + # write-back + dk = dk * alpha + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_v_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_qk_d[None, :]) + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None]) + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None]) + + +@triton.jit +def _hstu_attn_bwd( # noqa C901 + Q, K, V, Grad, DQ, DK, DV, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_kn: tl.constexpr, + stride_kh: tl.constexpr, + stride_vn: tl.constexpr, + stride_vh: tl.constexpr, + stride_dom: tl.constexpr, + stride_doh: tl.constexpr, + seq_offsets, + alpha: tl.constexpr, + batch: tl.constexpr, + head_num: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + head_dim: tl.constexpr, + CAUSAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + bias, +): + off = tl.program_id(0) + off_batch = off // head_num + off_head = off % head_num + off_head = off_head.to(tl.int64) + seq_start = tl.load(seq_offsets + off_batch).to(tl.int64) + seq_end = tl.load(seq_offsets + off_batch + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + # offset pointers for batch/head + q_offset = seq_start * stride_qm + off_head * stride_qh + k_offset = seq_start * stride_kn + off_head * stride_kh + v_offset = seq_start * stride_vn + off_head * stride_vh + grad_offset = seq_start * stride_dom + off_head * stride_doh + bias_offset = off_batch * head_num * MAX_SEQ_LEN * MAX_SEQ_LEN + off_head * MAX_SEQ_LEN * MAX_SEQ_LEN + for start_n in range(0, seq_len, 
BLOCK_N): + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + Q=Q + q_offset, + K=K + k_offset, + V=V + v_offset, + DOut=Grad + grad_offset, + DQ=DQ + q_offset, + DK=DK + k_offset, + DV=DV + v_offset, + stride_qm=stride_qm, + stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_qm, + stride_dkn=stride_kn, + stride_dvn=stride_vn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + CAUSAL=CAUSAL, + HAS_BIAS=HAS_BIAS, + BLOCK_D_Q=head_dim, + BLOCK_D_V=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + bias=bias + bias_offset if HAS_BIAS else bias, + ) + + +def triton_hstu_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + alpha: float, + causal: bool, + mask: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch = seq_offsets.numel() - 1 + total_seq, head_num, head_dim = q.shape + out = torch.empty_like(v) + BLOCK_M = BLOCK_FWD + BLOCK_N = BLOCK_FWD + if total_seq == 0: + print("error") + return out + has_bias = bias is not None + core_num = get_npu_properties('num_aicore') + tasks = total_seq * head_num // BLOCK_M // 2 + grid = (core_num, 1, 1) + _hstu_attn_fwd[grid](q, k, v, seq_offsets, out, q.stride(0), q.stride(1), k.stride(0), k.stride(1), + v.stride(0), v.stride(1), out.stride(0), out.stride(1), alpha, batch, head_num, max_seq_len, head_dim, + causal, has_bias, core_num, tasks, BLOCK_M, BLOCK_N, mask, bias, + ) + return out + + +def triton_hstu_attention_bwd( + grad: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + alpha: float, + causal: bool, + bias: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + if grad.shape[0] == 0: + return dq, dk, dv + batch = seq_offsets.numel() - 1 + _, head_num, head_dim = q.shape + has_bias = bias is not None + grid = (batch * head_num, 1,) + _hstu_attn_bwd[grid](q, k, v, grad, dq, dk, dv, + q.stride(0), q.stride(1), k.stride(0), k.stride(1), v.stride(0), v.stride(1), + grad.stride(0), grad.stride(1), seq_offsets, alpha, batch, head_num, max_seq_len, head_dim, + causal, has_bias, BLOCK_BWD, BLOCK_BWD, bias, + ) + return dq, dk, dv + + +def jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, dataType): + seq_array = np.arange(256, max_seq_len + 1, 256) + seq_lens = np.random.choice(seq_array, size=batch_size) + if not np.isin(max_seq_len, seq_lens): + seq_lens[np.random.randint(0, batch_size)] = max_seq_len + seq_offset = torch.concat((torch.zeros((1,), dtype=torch.int64), \ + torch.cumsum(torch.from_numpy(seq_lens), axis=0))).to(torch.int64).numpy() + max_seq_len = np.max(seq_lens) + total_seqs = np.sum(seq_lens) + grad = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + q = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + k = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + v = torch.rand((int(total_seqs), num_heads, attention_dim), dtype=dataType) + print("batch_size:", batch_size, ", max_seq_len :", max_seq_len, ", head_nums:", num_heads, ", head_dim:", attention_dim) + print("total_seqs:", total_seqs, "\nseq_lens:", seq_lens, "\nseq_offset:", seq_offset) + + bias = torch.empty(batch_size, num_heads, max_seq_len, max_seq_len, dtype=data_type).uniform_(-1, 1) + mask = 1 - torch.triu(torch.ones(batch_size, num_heads, max_seq_len, 
max_seq_len), diagonal=1).to(torch.float32) + return grad, q, k, v, bias, mask, max_seq_len, seq_offset + + +def dense_to_jagged(q, dense_tensor, seq_lens): + tensor = torch.zeros_like(q) + offset = 0 + for batch_id, seq_len in enumerate(seq_lens): + tensor[offset: offset + seq_len, :, :] = dense_tensor[batch_id, 0: seq_len, :, :] + offset = offset + seq_len + return tensor + + +def jagged_to_dense(jagged_tensor, seq_lens, head_nums, atten_dim): + need_pad_seq = [] + offset = 0 + for _, seq_len in enumerate(seq_lens): + src_tensor = jagged_tensor[offset: offset + seq_len, :, :].reshape(seq_len, head_nums, atten_dim) + need_pad_seq.append(src_tensor) + offset = offset + seq_len + + dense_tensor = torch.nn.utils.rnn.pad_sequence(need_pad_seq, batch_first=True) + return dense_tensor + + +def gloden_fwd(q, k, v, mask, alpha, seq_offset, attnBias, max_seq_len, enable_mask, enableBias, dataType): + head_nums = q.shape[1] + head_dim = q.shape[2] + batch_size = attnBias.shape[0] + seq_lens = np.zeros((batch_size, )).astype(np.int64) + for batch_id in range(batch_size): + seq_lens[batch_id] = seq_offset[batch_id + 1] - seq_offset[batch_id] + q_dens = jagged_to_dense(q, seq_lens, head_nums, head_dim).to(dataType) + k_dens = jagged_to_dense(k, seq_lens, head_nums, head_dim).to(dataType) + v_dens = jagged_to_dense(v, seq_lens, head_nums, head_dim).to(dataType) + q_dens = q_dens.permute(0, 2, 1, 3) + k_dens = k_dens.permute(0, 2, 3, 1) + v_dens = v_dens.permute(0, 2, 1, 3) + + qk_attn = torch.matmul(q_dens, k_dens) * alpha + qk_attn = qk_attn.to(torch.float32) + attnBias = attnBias.to(torch.float32) + mask = mask.to(torch.float32) + if enableBias: + qk_attn = qk_attn + attnBias + silu = F.silu(qk_attn) * (1 / max_seq_len) + if enable_mask: + silu = silu * mask + silu = silu.to(dataType) + atten_output = torch.matmul(silu, v_dens) + + atten_output = atten_output.permute(0, 2, 1, 3) + atten_output = dense_to_jagged(q, atten_output, seq_lens) + return atten_output.to(dataType) + + +def test_fwd(batch_size, max_seq_len, num_heads, attention_dim, data_type): + alpha = 1 # 0.5 + _, q, k, v, bias, mask, max_seq_len, seq_offset = jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, data_type) + # golden 输出 + golden_output = gloden_fwd(q, k, v, mask, alpha, seq_offset, bias, max_seq_len, True, False, data_type) + # triton 输出 + seq_offsets = torch.tensor(seq_offset, dtype=torch.int64, device=DEVICE) + triton_output = triton_hstu_attention_fwd( + q=q.npu(), + k=k.npu(), + v=v.npu(), + seq_offsets=seq_offsets, + max_seq_len=int(max_seq_len), + alpha=alpha, + causal=True, + mask=mask.npu(), + ) + loss = 1e-4 + if data_type == torch.float16: + loss = 1e-3 + elif data_type == torch.bfloat16: + loss = 1e-2 + compare_result = torch.allclose(triton_output.cpu(), golden_output, loss, loss) + + if compare_result: + write_result = 'ACC PASS' + else: + write_result = 'ACC FAIL' + print(f"compare result: {write_result}") + + +def golden_bwd(grad, q, k, v, bias, mask, max_seq_len, seq_offset, enable_mask, silu_scale, enable_bias, data_type): + def jagged_to_dense_bwd(jagged_tensor, seq_lens, max_seq_len, head_num, head_dim): + batch_size = len(seq_lens) + dense_tensor = torch.zeros(batch_size, max_seq_len, head_num, head_dim, dtype=jagged_tensor.dtype) + + offset = 0 + for batch_id, seq_len in enumerate(seq_lens): + dense_tensor[batch_id, :seq_len, :, :] = jagged_tensor[offset: offset + seq_len, :, :] + offset = offset + seq_len + + return dense_tensor + + def dense_to_jagged_bwd(jagged_tensor, dense_tensor, 
seq_lens): + tensor = torch.zeros_like(jagged_tensor) + + offset = 0 + for batch_id, seq_len in enumerate(seq_lens): + tensor[offset: offset + seq_len, :, :] = dense_tensor[batch_id, 0: seq_len, :, :] + offset = offset + seq_len + + return tensor + + q = q.cpu() + k = k.cpu() + v = v.cpu() + grad = grad.cpu() + head_nums = grad.shape[1] + head_dim = grad.shape[2] + batch_size = bias.shape[0] + seq_lens = np.zeros((batch_size,)).astype(np.int64) + for batch_id in range(batch_size): + seq_lens[batch_id] = seq_offset[batch_id + 1] - seq_offset[batch_id] + grad_dens = jagged_to_dense_bwd(grad, seq_lens, max_seq_len, head_nums, head_dim).to(data_type) + q_dens = jagged_to_dense_bwd(q, seq_lens, max_seq_len, head_nums, head_dim).to(data_type) + k_dens = jagged_to_dense_bwd(k, seq_lens, max_seq_len, head_nums, head_dim).to(data_type) + v_dens = jagged_to_dense_bwd(v, seq_lens, max_seq_len, head_nums, head_dim).to(data_type) + actual_seq_lens = torch.from_numpy(seq_lens).reshape(batch_size, 1, 1, 1).to(data_type) + actual_seq_lens = torch.broadcast_to(actual_seq_lens, bias.shape) + qk = torch.matmul(q_dens.permute(0, 2, 1, 3), k_dens.permute(0, 2, 3, 1)) + gv = torch.matmul(grad_dens.permute(0, 2, 1, 3), v_dens.permute(0, 2, 3, 1)) + qk = qk.float() + gv = gv.float() + bias = bias.float() + if enable_mask: + mask = mask.to(data_type) + mask = mask.float() + if enable_bias: + bias = bias.to(data_type) + bias = bias.float() + qkb = qk + bias + else: + qkb = qk + real_silu_scale = 1 / max_seq_len if silu_scale == 0.0 else silu_scale + + if enable_mask: + score = F.silu(qkb) * real_silu_scale * mask + else: + score = F.silu(qkb) * real_silu_scale + score = score.to(data_type) + v_grad_dens = torch.matmul(score.permute(0, 1, 3, 2), grad_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3) + if enable_mask: + bias_grad = gv * real_silu_scale * mask * F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb))) + else: + bias_grad = gv * real_silu_scale * F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb))) + bias_grad = bias_grad.to(data_type) + k_grad_dens = torch.matmul(bias_grad.permute(0, 1, 3, 2), q_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3) + q_grad_dens = torch.matmul(bias_grad, k_dens.permute(0, 2, 1, 3)).permute(0, 2, 1, 3) + bias_grad = bias_grad.cpu() + q_grad_dens = q_grad_dens.cpu() + q_grad = dense_to_jagged_bwd(q, q_grad_dens, seq_lens) + k_grad_dens = k_grad_dens.cpu() + k_grad = dense_to_jagged_bwd(k, k_grad_dens, seq_lens) + v_grad_dens = v_grad_dens.cpu() + v_grad = dense_to_jagged_bwd(v, v_grad_dens, seq_lens) + torch.npu.synchronize() + return q_grad, k_grad, v_grad, bias_grad + + +def test_bwd(batch_size, max_seq_len, num_heads, attention_dim, data_type): + alpha = 1 # 0.5 + grad, q, k, v, bias, mask, max_seq_len, seq_offset = jagged_data_gen(batch_size, max_seq_len, num_heads, attention_dim, data_type) + # golden 输出 + q_grad_golden, k_grad_golden, v_grad_golden, _ = golden_bwd(grad, q, k, v, bias, mask, + max_seq_len, seq_offset, True, 0, False, data_type) + + # triton 输出 + seq_offsets = torch.tensor(seq_offset, dtype=torch.int64, device=DEVICE) + dq, dk, dv = triton_hstu_attention_bwd( + grad=grad.npu(), + q=q.npu(), + k=k.npu(), + v=v.npu(), + seq_offsets=seq_offsets, + max_seq_len=int(max_seq_len), + alpha=alpha, + causal=True, + ) + loss = 1e-4 + if data_type == torch.float16: + loss = 1e-3 + elif data_type == torch.bfloat16: + loss = 1e-2 + q_res = torch.allclose(dq.cpu(), q_grad_golden.cpu(), loss, loss) + k_res = torch.allclose(dk.cpu(), k_grad_golden.cpu(), loss, loss) + v_res = 
torch.allclose(dv.cpu(), v_grad_golden.cpu(), loss, loss) + if q_res and k_res and v_res: + write_result = 'ACC PASS' + else: + write_result = 'ACC FAIL' + print("dq res : ", q_res, " dk res : ", k_res, " dv res : ", v_res) + print(f"compare result: {write_result}") + + +if __name__ == "__main__": + #取值范围: 1~2048 + batch_size = 2 + #256的倍数,范围:[256,4096]; + max_seq_len = 1024 + #取值: 2/4/6/8 + num_heads = 2 + #取值: 32/64/128/256 + attention_dim = 32 + data_type = torch.float32 + print("Running hstu attention forward test:") + test_fwd(batch_size, max_seq_len, num_heads, attention_dim, data_type) + print("Running hstu attention backward test:") + test_bwd(batch_size, max_seq_len, num_heads, attention_dim, data_type) \ No newline at end of file diff --git a/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py b/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py new file mode 100644 index 000000000..e1c0d223a --- /dev/null +++ b/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py @@ -0,0 +1,214 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +@triton.autotune( +configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), + ], + key=["M", "N", "K"] +) +@triton.jit +def matmul_kernel( + mat_a, mat_b, mat_c, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + num_cores: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_TRESHHOLD: tl.constexpr, +): + pid = tl.program_id(axis=0) + task_m_idx = 0 + task_n_idx = 0 + + ''' + 水平分核方式每个任务块编号如下 + [0, 1, 2, 3, 4, 5, 6, 7] + [8, 9, 10, 11, 12, 13, 14, 15] + [16, 17, 18, 19, 20, 21, 22, 23] + [24, 25, 26, 27, 28, 29, 30, 31] + [32, 33, 34, 35, 36, 37, 38, 39] + [40, 41, 42, 43, 44, 45, 46, 47] + [48, 49, 50, 51, 52, 53, 54, 55] + [56, 57, 58, 59, 60, 61, 62, 63] + 0核处理 0 20 40 60 4块任务 + 1核处理 1 21 41 61 4块任务 + 2核处理 2 22 42 62 4块任务 + ... 
+ 19核处理 19 39 59 3块任务 + + 大shape下如果使用传统水平分核方式,会有如下问题 + 1:同一时间大量核心需要访问同一块左矩阵内存,产生Bank冲突,导致硬件访问效率降低 + 2:当完成一整行mat_c运算时,已经将所有右矩阵数据全部使用上,右矩阵较大时会超过L2Cache的容量上限, + 从而导致L2Cache的搬入及换出,此后每行运算都会或多或少产生CacheMiss,导致L2Cche命中率较低,影响 + 算子执行效率 + 此处使用8 * 8对角线分核方式可以按8 * 8的方块沿对角线方向分核计算,可以很大程度优化上面两点。 + + 此处以8*8对角线分核为例,实际以BLOCK_TRESHHOLD为tune参数选择最优的阈值 + 8 * 8 对角线分核方式中,每8 * 8分格内任务块编号如下 + [0, 8, 16, 24, 32, 40, 48, 56] + [57, 1, 9, 17, 25, 33, 41, 49] + [50, 58, 2, 10, 18, 26, 34, 42] + [43, 51, 59, 3, 11, 19, 27, 35] + [36, 44, 52, 60, 4, 12, 20, 28] + [29, 37, 45, 53, 61, 5, 13, 21] + [22, 30, 38, 46, 54, 62, 6, 14] + [15, 23, 31, 39, 47, 55, 63, 7] + + M轴方向超过8个基本块时,使用对角线分核可以明显减小Bank冲突 + 当右矩阵大小超过L2Cache大小时,采取对角线分核可以提升L2Cache利用率 + 所以当矩阵在M和N方向均超过8块时使能对角线分核即可有优化,当右矩阵大小超过L2Cache大小时优化效果尤为明显 + ''' + NUM_BLOCKS_M = triton.cdiv(M, BLOCK_M) + NUM_BLOCKS_N = triton.cdiv(N, BLOCK_N) + NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N + #当任务量较多时,可以使能对角线分核策略进行优化 + if NUM_BLOCKS_M >= BLOCK_TRESHHOLD and NUM_BLOCKS_N >= BLOCK_TRESHHOLD: + for block_idx in range ( + pid, NUM_BLOCKS, num_cores + ): + #8 * 8 对角线分核代码实现 + curThresholdM = BLOCK_TRESHHOLD if block_idx < (NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD + curThresholdM_thresholdN = curThresholdM * BLOCK_TRESHHOLD + curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < (curThresholdM * NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD + localRelativeBlock = block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) % (BLOCK_TRESHHOLD * curThresholdM) + task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * NUM_BLOCKS_N) * BLOCK_TRESHHOLD + #求最小公倍数,方便求基本块的坐标 + x, y = curThresholdM, curThresholdN if curThresholdM > curThresholdN else curThresholdN, curThresholdM + while y != 0: + x, y = y, x % y + lcm = curThresholdM * curThresholdN // x + task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD + + m_start = task_m_idx * BLOCK_M + n_start = task_n_idx * BLOCK_N + + mat_c_block = tl.zeros((BLOCK_M, BLOCK_N),dtype = tl.float32) + for k_start in range(0, K, BLOCK_K): + mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + ( + k_start + tl.arange(0, BLOCK_K) + )[None, :] + mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (k_start + tl.arange(0, BLOCK_K)) < K + )[None, :] + mat_a_block = tl.load(mat_a + mat_a_offset, mask = mat_a_mask, other = 0.0) # @hint: dot_pad_only_k + mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + mat_b_block = tl.load(mat_b + mat_b_offset, mask = mat_b_mask, other = 0.0) # @hint: dot_pad_only_k + mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask = mat_c_mask) + else: + #传统顺序分核 + for block_idx in range ( + pid, NUM_BLOCKS, num_cores + ): + task_m_idx = block_idx // NUM_BLOCKS_N + task_n_idx = block_idx % NUM_BLOCKS_N + m_start = task_m_idx * BLOCK_M + n_start = task_n_idx * 
BLOCK_N + + mat_c_block = tl.zeros((BLOCK_M, BLOCK_N),dtype = tl.float32) + for k_start in range(0, K, BLOCK_K): + mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + ( + k_start + tl.arange(0, BLOCK_K) + )[None, :] + mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (k_start + tl.arange(0, BLOCK_K)) < K + )[None, :] + mat_a_block = tl.load(mat_a + mat_a_offset, mask = mat_a_mask, other = 0.0) # @hint: dot_pad_only_k + mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + mat_b_block = tl.load(mat_b + mat_b_offset, mask = mat_b_mask, other = 0.0) # @hint: dot_pad_only_k + mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask = mat_c_mask) + +def triton_matmul( + mat_a, + mat_b, +): + m = mat_a.shape[0] + k = mat_a.shape[1] + n = mat_b.shape[1] + mat_c = torch.empty(m, n, dtype=mat_a.dtype, device=mat_a.device) + + ''' + NPU芯片更加亲和512B对齐场景,如下分块通用性能较好,可以使用autotune选取最优 + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 256 + ''' + + num_cores = get_npu_properties()["num_aicore"] + + matmul_kernel[(num_cores,)] ( + mat_a, + mat_b, + mat_c, + m, + n, + k, + num_cores + ) + return mat_c + + +if __name__ == "__main__": + M = 2048 + K = 7168 + N = 16384 + + mat_a = torch.randn([M, K], dtype = torch.bfloat16, device = "npu") + mat_b = torch.randn([K, N], dtype = torch.bfloat16, device = "npu") + + result = triton_matmul(mat_a, mat_b) + golden = torch.matmul(mat_a, mat_b) + + mask = golden.abs() < 1.0 + tmpatol = tmprtol = 2 ** -6 + try: + torch.testing.assert_close(result[mask], golden[mask], atol = tmpatol, rtol = 0) + torch.testing.assert_close(result[~mask], golden[~mask], atol = 0, rtol = tmprtol) + print("run matmul success") + except: + print(f"[ERROR] M={M} ,K={K}, N={N}存在精度问题") \ No newline at end of file diff --git a/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized.py b/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized.py new file mode 100644 index 000000000..c35bec9f2 --- /dev/null +++ b/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized.py @@ -0,0 +1,218 @@ +import torch +import torch_npu +import triton +import triton.language as tl +import triton.runtime.driver as driver + + +# get device properties of npu +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +@triton.autotune( +configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}), + triton.Config({"BLOCK_M": 256, 
"BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), + ], + key=["M", "N", "K"] +) +@triton.jit +def matmul_kernel( + mat_a, mat_b, mat_c, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + num_cores: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_TRESHHOLD: tl.constexpr, +): + pid = tl.program_id(axis=0) + task_m_idx = 0 + task_n_idx = 0 + + ''' + 水平分核方式每个任务块编号如下 + [0, 1, 2, 3, 4, 5, 6, 7] + [8, 9, 10, 11, 12, 13, 14, 15] + [16, 17, 18, 19, 20, 21, 22, 23] + [24, 25, 26, 27, 28, 29, 30, 31] + [32, 33, 34, 35, 36, 37, 38, 39] + [40, 41, 42, 43, 44, 45, 46, 47] + [48, 49, 50, 51, 52, 53, 54, 55] + [56, 57, 58, 59, 60, 61, 62, 63] + 0核处理 0 20 40 60 4块任务 + 1核处理 1 21 41 61 4块任务 + 2核处理 2 22 42 62 4块任务 + ... + 19核处理 19 39 59 3块任务 + + 大shape下如果使用传统水平分核方式,会有如下问题 + 1:同一时间大量核心需要访问同一块左矩阵内存,产生Bank冲突,导致硬件访问效率降低 + 2:当完成一整行mat_c运算时,已经将所有右矩阵数据全部使用上,右矩阵较大时会超过L2Cache的容量上限, + 从而导致L2Cache的搬入及换出,此后每行运算都会或多或少产生CacheMiss,导致L2Cche命中率较低,影响 + 算子执行效率 + 此处使用8 * 8对角线分核方式可以按8 * 8的方块沿对角线方向分核计算,可以很大程度优化上面两点。 + + 此处以8*8对角线分核为例,实际以BLOCK_TRESHHOLD为tune参数选择最优的阈值 + 8 * 8 对角线分核方式中,每8 * 8分格内任务块编号如下 + [0, 8, 16, 24, 32, 40, 48, 56] + [57, 1, 9, 17, 25, 33, 41, 49] + [50, 58, 2, 10, 18, 26, 34, 42] + [43, 51, 59, 3, 11, 19, 27, 35] + [36, 44, 52, 60, 4, 12, 20, 28] + [29, 37, 45, 53, 61, 5, 13, 21] + [22, 30, 38, 46, 54, 62, 6, 14] + [15, 23, 31, 39, 47, 55, 63, 7] + + M轴方向超过8个基本块时,使用对角线分核可以明显减小Bank冲突 + 当右矩阵大小超过L2Cache大小时,采取对角线分核可以提升L2Cache利用率 + 所以当矩阵在M和N方向均超过8块时使能对角线分核即可有优化,当右矩阵大小超过L2Cache大小时优化效果尤为明显 + ''' + NUM_BLOCKS_M = triton.cdiv(M, BLOCK_M) + NUM_BLOCKS_N = triton.cdiv(N, BLOCK_N) + NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N + #当任务量较多时,可以使能对角线分核策略进行优化 + if NUM_BLOCKS_M >= BLOCK_TRESHHOLD and NUM_BLOCKS_N >= BLOCK_TRESHHOLD: + for block_idx in range ( + pid, NUM_BLOCKS, num_cores + ): + #8 * 8 对角线分核代码实现 + curThresholdM = BLOCK_TRESHHOLD if block_idx < (NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD + curThresholdM_thresholdN = curThresholdM * BLOCK_TRESHHOLD + curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < (curThresholdM * NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD + localRelativeBlock = block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) % (BLOCK_TRESHHOLD * curThresholdM) + task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * NUM_BLOCKS_N) * BLOCK_TRESHHOLD + #求最小公倍数,方便求基本块的坐标 + x, y = curThresholdM, curThresholdN if curThresholdM > curThresholdN else curThresholdN, curThresholdM + while y != 0: + x, y = y, x % y + lcm = curThresholdM * curThresholdN // x + task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD + + m_start = task_m_idx * BLOCK_M + n_start = task_n_idx * BLOCK_N + + mat_c_block = tl.zeros((BLOCK_M, BLOCK_N),dtype = tl.float32) + for k_start in range(0, K, BLOCK_K): + mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + ( + k_start + tl.arange(0, BLOCK_K) + )[None, :] + mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (k_start + tl.arange(0, BLOCK_K)) < K + )[None, :] + mat_a_block = tl.load(mat_a + mat_a_offset, mask = mat_a_mask, other = 0.0) + 
tl.compile_hint(mat_a_block, "dot_pad_only_k") + mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + mat_b_block = tl.load(mat_b + mat_b_offset, mask = mat_b_mask, other = 0.0) + tl.compile_hint(mat_b_block, "dot_pad_only_k") + mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask = mat_c_mask) + else: + #传统顺序分核 + for block_idx in range ( + pid, NUM_BLOCKS, num_cores + ): + task_m_idx = block_idx // NUM_BLOCKS_N + task_n_idx = block_idx % NUM_BLOCKS_N + m_start = task_m_idx * BLOCK_M + n_start = task_n_idx * BLOCK_N + + mat_c_block = tl.zeros((BLOCK_M, BLOCK_N),dtype = tl.float32) + for k_start in range(0, K, BLOCK_K): + mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + ( + k_start + tl.arange(0, BLOCK_K) + )[None, :] + mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (k_start + tl.arange(0, BLOCK_K)) < K + )[None, :] + mat_a_block = tl.load(mat_a + mat_a_offset, mask = mat_a_mask, other = 0.0) + tl.compile_hint(mat_a_block, "dot_pad_only_k") + mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + mat_b_block = tl.load(mat_b + mat_b_offset, mask = mat_b_mask, other = 0.0) + tl.compile_hint(mat_b_block, "dot_pad_only_k") + mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) + mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + ( + n_start + tl.arange(0, BLOCK_N) + )[None, :] + mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( + (n_start + tl.arange(0, BLOCK_N)) < N + )[None, :] + tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask = mat_c_mask) + +def triton_matmul( + mat_a, + mat_b, +): + m = mat_a.shape[0] + k = mat_a.shape[1] + n = mat_b.shape[1] + mat_c = torch.empty(m, n, dtype=mat_a.dtype, device=mat_a.device) + + ''' + NPU芯片更加亲和512B对齐场景,如下分块通用性能较好,可以使用autotune选取最优 + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 256 + ''' + + num_cores = get_npu_properties()["num_aicore"] + + matmul_kernel[(num_cores,)] ( + mat_a, + mat_b, + mat_c, + m, + n, + k, + num_cores + ) + return mat_c + + +if __name__ == "__main__": + M = 2048 + K = 7168 + N = 16384 + + mat_a = torch.randn([M, K], dtype = torch.bfloat16, device = "npu") + mat_b = torch.randn([K, N], dtype = torch.bfloat16, device = "npu") + + result = triton_matmul(mat_a, mat_b) + golden = torch.matmul(mat_a, mat_b) + + mask = golden.abs() < 1.0 + tmpatol = tmprtol = 2 ** -6 + try: + torch.testing.assert_close(result[mask], golden[mask], atol = tmpatol, rtol = 0) + torch.testing.assert_close(result[~mask], golden[~mask], atol = 0, rtol = tmprtol) + print("run matmul success") + except: + print(f"[ERROR] M={M} ,K={K}, N={N}存在精度问题") \ No newline at end of file diff --git a/third_party/ascend/examples/tutorials/14-accuracy-comparison.py b/third_party/ascend/examples/tutorials/14-accuracy-comparison.py new file mode 100644 index 000000000..e44006d9f --- /dev/null +++ 
b/third_party/ascend/examples/tutorials/14-accuracy-comparison.py @@ -0,0 +1,153 @@ +import torch +import triton +import triton.language as tl + + +def test_add(x0, x1): + """ + 测试 Triton 实现的向量加法与 PyTorch 的结果,精度比对是否一致。 + + 步骤: + 1. 使用 PyTorch 计算参考结果(torch_ref) + 2. 使用 Triton 编写 kernel 并计算结果(triton_cal) + 3. 调用 accuracy_comparison 进行精度比对 + """ + + # 1. 使用 PyTorch 作为参考实现(golden truth) + def torch_func(x0, x1): + res = x0 + x1 + return res + + # 2. 定义 Triton kernel(在 NPU/GPU 上执行) + @triton.jit + def triton_kernel_add( + out_ptr0, # 输出指针:结果存储位置 + in_ptr0, # 输入指针0:x0 的起始地址 + in_ptr1, # 输入指针1:x1 的起始地址 + XS: tl.constexpr # constexpr 参数:向量长度,在编译时确定 + ): + # 生成 [0, 1, 2, ..., XS-1] 的索引数组 + idx = tl.arange(0, XS) + # 从 in_ptr0 + idx 处加载 x0 的值 + tmp0 = tl.load(in_ptr0 + idx) + # 从 in_ptr1 + idx 处加载 x1 的值 + tmp1 = tl.load(in_ptr1 + idx) + # 执行加法 + tmp2 = tmp0 + tmp1 + # 将结果写入 out_ptr0 + idx + tl.store(out_ptr0 + idx, tmp2) + + # 3. Triton 封装函数:调用 kernel 并返回结果 + def triton_func(x0, x1): + y0 = torch.empty_like(x0) # 创建与输入形状、dtype 相同的输出张量 + # 启动 kernel:grid = [1, 1, 1] 表示仅使用一个 block + # 注意:XS 必须作为参数传入,因为它是 tl.constexpr 类型 + triton_kernel_add[1, 1, 1](y0, x0, x1, XS=x0.numel()) + return y0 + + # 4. 获取参考结果和 Triton 计算结果 + torch_ref = torch_func(x0, x1) + triton_cal = triton_func(x0, x1) + + # 5. 精度比对 + accuracy_comparison(triton_cal, torch_ref) + + # 6. 打印成功信息 + print( + f"== dtype:{triton_cal.dtype} == The accuracy comparison between triton_result and torch_result was successful.") + + +def accuracy_comparison(y_cal, y_ref): + """ + 精度比对函数:根据数据类型选择合适的比对策略。 + + 不同数据类型的处理策略: + - 浮点类型(float16/32, bfloat16):使用 torch.testing.assert_close,设置相对/绝对误差容限 + - 整数类型(int8/16/32/64):要求完全相等(torch.equal) + - 布尔类型(bool):CPU 上严格比较(避免设备差异) + """ + # 检查输出数据类型是否一致 + assert y_cal.dtype == y_ref.dtype, f"dtype mismatch: {y_cal.dtype} vs {y_ref.dtype}" + tensor_dtype = y_cal.dtype + + # 将张量移动到 NPU(假设测试在 NPU 上进行) + y_cal = y_cal.npu() + y_ref = y_ref.npu() + + # 根据数据类型选择不同的比对方式 + if tensor_dtype == torch.float16: + # float16 精度较低,允许稍大误差 + torch.testing.assert_close(y_ref, y_cal, rtol=1e-3, atol=1e-3, equal_nan=True) + elif tensor_dtype == torch.bfloat16: + # bfloat16 精度更低,建议转为 float32 再比较 + torch.testing.assert_close( + y_ref.to(torch.float32), + y_cal.to(torch.float32), + rtol=1e-3, + atol=1e-3, + equal_nan=True + ) + elif tensor_dtype == torch.float32: + # float32 精度较高,使用更严格的容差 + torch.testing.assert_close(y_ref, y_cal, rtol=1e-4, atol=1e-4, equal_nan=True) + elif tensor_dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint32]: + # 整数类型应完全相等 + assert torch.equal(y_cal, y_ref), f"Integer tensors are not equal for dtype {tensor_dtype}" + elif tensor_dtype == torch.bool: + # 布尔类型建议在 CPU 上比较,避免设备间布尔表示差异 + assert torch.equal(y_cal.cpu(), y_ref.cpu()), "Boolean tensors are not equal" + else: + raise ValueError(f'Invalid or unsupported tensor dtype: {tensor_dtype}') + + +# ======================== +# 主程序入口 +# ======================== +if __name__ == "__main__": + # 向量长度 + N = 1024 + # 整数随机数范围 + low = 1 + high = 100 + + # 创建各种数据类型的测试张量(已移至 NPU) + x0_fp32 = torch.rand((N,), dtype=torch.float32).npu() + x1_fp32 = torch.rand((N,), dtype=torch.float32).npu() + + x0_fp16 = torch.rand((N,), dtype=torch.float16).npu() + x1_fp16 = torch.rand((N,), dtype=torch.float16).npu() + + x0_bf16 = torch.rand((N,), dtype=torch.bfloat16).npu() + x1_bf16 = torch.rand((N,), dtype=torch.bfloat16).npu() + + x0_i64 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int64).npu() + x1_i64 = torch.randint(low=low, high=high, size=(N,), 
dtype=torch.int64).npu() + + x0_i32 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int32).npu() + x1_i32 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int32).npu() + + x0_i16 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int16).npu() + x1_i16 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int16).npu() + + x0_i8 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int8).npu() + x1_i8 = torch.randint(low=low, high=high, size=(N,), dtype=torch.int8).npu() + + x0_i1 = torch.randint(low=0, high=2, size=(N,)).bool().npu() + x1_i1 = torch.randint(low=0, high=2, size=(N,)).bool().npu() + + # 测试用例列表:(名称, x0, x1) + test_cases = [ + ('fp32', x0_fp32, x1_fp32), + ('fp16', x0_fp16, x1_fp16), + ('bf16', x0_bf16, x1_bf16), + ('i64', x0_i64, x1_i64), + ('i32', x0_i32, x1_i32), + ('i16', x0_i16, x1_i16), + ('i8', x0_i8, x1_i8), + ('i1', x0_i1, x1_i1), + ] + + # 遍历所有测试用例 + for dtype_name, x0, x1 in test_cases: + print(f"Running test for {dtype_name}...") + test_add(x0, x1) diff --git a/third_party/ascend/language/ascend/__init__.py b/third_party/ascend/language/ascend/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/ascend/language/ascend/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/ascend/libdevice.py new file mode 100644 index 000000000..9098856ec --- /dev/null +++ b/third_party/ascend/language/ascend/libdevice.py @@ -0,0 +1,155 @@ +from triton.language import core + +@core.extern +def reciprocal(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_recipf", core.dtype("fp32")), + (core.dtype("fp16"),): ("__hmf_recipDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_log1pf", core.dtype("fp32")), + (core.dtype("fp16"),): ("__hmf_log1pDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def relu(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_reluf", core.dtype("fp32")), + (core.dtype("fp16"),): ("__hmf_reluDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_isinf", core.dtype("int1")), + (core.dtype("fp16"),): ("__hmf_isinf", core.dtype("int1")), + (core.dtype("bf16"),): ("__hmf_isinf", core.dtype("int1")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_tanf", core.dtype("fp32")), + (core.dtype("fp16"),): ("__hmf_tanDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"),): ("__hmf_atanf", core.dtype("fp32")), + (core.dtype("fp16"),): ("__hmf_atanDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")), + }, is_pure=True, _builder=_builder) + +@core.extern +def 
diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/ascend/libdevice.py
new file mode 100644
index 000000000..9098856ec
--- /dev/null
+++ b/third_party/ascend/language/ascend/libdevice.py
@@ -0,0 +1,155 @@
+from triton.language import core
+
+@core.extern
+def reciprocal(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_recipf", core.dtype("fp32")),
+            (core.dtype("fp16"),): ("__hmf_recipDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def log1p(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_log1pf", core.dtype("fp32")),
+            (core.dtype("fp16"),): ("__hmf_log1pDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def relu(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_reluf", core.dtype("fp32")),
+            (core.dtype("fp16"),): ("__hmf_reluDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def isinf(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_isinf", core.dtype("int1")),
+            (core.dtype("fp16"),): ("__hmf_isinf", core.dtype("int1")),
+            (core.dtype("bf16"),): ("__hmf_isinf", core.dtype("int1")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def tan(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_tanf", core.dtype("fp32")),
+            (core.dtype("fp16"),): ("__hmf_tanDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def atan(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_atanf", core.dtype("fp32")),
+            (core.dtype("fp16"),): ("__hmf_atanDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def tanh(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"), ): ("__hmf_tanhf", core.dtype("fp32")),
+            (core.dtype("fp16"), ): ("__hmf_tanhDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def ilogb(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_ilogbf", core.dtype("fp32")),
+            (core.dtype("fp16"),): ("__hmf_ilogbDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def ldexp(arg0, arg1, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0, arg1], {
+            (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_ldexpf", core.dtype("fp32")),
+            (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_ldexpDh", core.dtype("fp16")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def pow(arg0, arg1, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0, arg1], {
+            (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_powf", core.dtype("fp32")),
+            (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_powf", core.dtype("fp16")),
+            (core.dtype("bf16"), core.dtype("bf16")): ("__hmf_powf", core.dtype("bf16")),
+            (core.dtype("int64"), core.dtype("int64")): ("__hmf_powi", core.dtype("int64")),
+            (core.dtype("int32"), core.dtype("int32")): ("__hmf_powi", core.dtype("int32")),
+            (core.dtype("int16"), core.dtype("int16")): ("__hmf_powi", core.dtype("int16")),
+            (core.dtype("int8"), core.dtype("int8")): ("__hmf_powi", core.dtype("int8")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def isnan(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"),): ("__hmf_isnan", core.dtype("int1")),
+            (core.dtype("fp16"),): ("__hmf_isnan", core.dtype("int1")),
+            (core.dtype("bf16"),): ("__hmf_isnan", core.dtype("int1")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def flip(arg0, arg1=None, _builder=None):
+    if arg1 is None:
+        return core.extern_elementwise(
+            "", "", [arg0], {
+                (core.dtype("bf16"), ): ("__hmf_flipDhb", core.dtype("bf16")),
+                (core.dtype("fp16"), ): ("__hmf_flipDh", core.dtype("fp16")),
+                (core.dtype("fp32"), ): ("__hmf_flipf", core.dtype("fp32")),
+                (core.dtype("int8"), ): ("__hmf_flipi8", core.dtype("int8")),
+                (core.dtype("int16"), ): ("__hmf_flipi16", core.dtype("int16")),
+                (core.dtype("int32"), ): ("__hmf_flipi32", core.dtype("int32")),
+                (core.dtype("uint32"), ): ("__hmf_flipui32", core.dtype("uint32")),
+                (core.dtype("int64"), ): ("__hmf_flipi64", core.dtype("int64")),
+            }, is_pure=True, _builder=_builder)
+
+    return core.extern_elementwise(
+        "", "", [arg0, arg1], {
+            (core.dtype("bf16"), core.dtype("int32")): ("__hmf_flipDhb", core.dtype("bf16")),
+            (core.dtype("fp16"), core.dtype("int32")): ("__hmf_flipDh", core.dtype("fp16")),
+            (core.dtype("fp32"), core.dtype("int32")): ("__hmf_flipf", core.dtype("fp32")),
+            (core.dtype("int8"), core.dtype("int32")): ("__hmf_flipi8", core.dtype("int8")),
+            (core.dtype("int16"), core.dtype("int32")): ("__hmf_flipi16", core.dtype("int16")),
+            (core.dtype("int32"), core.dtype("int32")): ("__hmf_flipi32", core.dtype("int32")),
+            (core.dtype("uint32"), core.dtype("int32")): ("__hmf_flipui32", core.dtype("uint32")),
+            (core.dtype("int64"), core.dtype("int32")): ("__hmf_flipi64", core.dtype("int64")),
+        }, is_pure=True, _builder=_builder)
+
+@core.extern
+def atan2(arg0, _builder=None):
+    core.static_print("tl.atan2 is unsupported for now. Use libdevice.atan2 instead.")
+    core.static_assert(False)
+
+@core.extern
+def div_rz(arg0, arg1, _builder=None):
+    core.static_print("tl.div_rz is unsupported for now. Use libdevice.div_rz instead.")
+    core.static_assert(False)
+
+@core.extern
+def fmod(arg0, arg1, _builder=None):
+    core.static_print("tl.fmod is unsupported for now. 
Use libdevice.fmod instead.") + core.static_assert(False) + +@core.extern +def trunc(arg0, _builder=None): + core.static_print("tl.trunc is unsupported for now. Use libdevice.trunc instead.") + core.static_assert(False) + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_roundf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) \ No newline at end of file diff --git a/third_party/ascend/test/CMakeLists.txt b/third_party/ascend/test/CMakeLists.txt new file mode 100644 index 000000000..e63d772a9 --- /dev/null +++ b/third_party/ascend/test/CMakeLists.txt @@ -0,0 +1,28 @@ + +llvm_canonicalize_cmake_booleans( + MLIR_ENABLE_BINDINGS_PYTHON +) + +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py +) + +set(TRITON_ADAPTER_TEST_DEPENDS + triton-adapter-opt +) + +set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck") +set(LIT_ARGS "-Dfilecheck=${FILECHECK_PATH}") + +add_lit_testsuite(check-triton-adapter-lit-tests "Running the triton-adapter regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + ARGS ${LIT_ARGS} + DEPENDS ${TRITON_ADAPTER_TEST_DEPENDS} + ) + +set_target_properties(check-triton-adapter-lit-tests PROPERTIES FOLDER "Tests") + +add_lit_testsuites(TRITON_ADAPTER_TIT_TESTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TRITON_ADAPTER_TEST_DEPENDS}) diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/add_kernel_attr.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/add_kernel_attr.mlir new file mode 100644 index 000000000..56fc3422f --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/add_kernel_attr.mlir @@ -0,0 +1,39 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @num_programs(%arg0: !tt.ptr) { + %0 = tt.get_num_programs x : i32 + %1 = tt.get_num_programs y : i32 + %2 = tt.get_num_programs z : i32 + %3 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %4 = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> + %5 = tt.make_range {end = 3 : i32, start = 2 : i32} : tensor<1xi32> + %6 = tt.splat %arg0 : !tt.ptr -> tensor<1x!tt.ptr> + %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr>, tensor<1xi32> + %8 = tt.splat %0 : i32 -> tensor<1xi32> + tt.store %7, %8 evictionPolicy = evict_last : tensor<1x!tt.ptr> + %9 = tt.addptr %6, %4 : tensor<1x!tt.ptr>, tensor<1xi32> + %10 = tt.splat %1 : i32 -> tensor<1xi32> + tt.store %9, %10 evictionPolicy = evict_last : tensor<1x!tt.ptr> + %11 = tt.addptr %6, %5 : tensor<1x!tt.ptr>, tensor<1xi32> + %12 = tt.splat %2 : i32 -> tensor<1xi32> + tt.store %11, %12 evictionPolicy = evict_last : tensor<1x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: module { +// CHECK: func.func @num_programs( +// CHECK-SAME: %arg0: memref, %arg1: memref, %arg2: memref {tt.tensor_kind = 1 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: %[[COMMON_EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1xi32> +// CHECK: %1 = linalg.fill ins(%arg3 : i32) outs(%[[COMMON_EMPTY_TENSOR]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: bufferization.materialize_in_destination %1 in writable %reinterpret_cast : 
(tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [1], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: 1>> +// CHECK: %2 = linalg.fill ins(%arg4 : i32) outs(%[[COMMON_EMPTY_TENSOR]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: bufferization.materialize_in_destination %2 in writable %reinterpret_cast_0 : (tensor<1xi32>, memref<1xi32, strided<[1], offset: 1>>) -> () +// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [2], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: 2>> +// CHECK: %3 = linalg.fill ins(%arg5 : i32) outs(%[[COMMON_EMPTY_TENSOR]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: bufferization.materialize_in_destination %3 in writable %reinterpret_cast_1 +// CHECK: return +// CHECK: } +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_2d_example.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_2d_example.mlir new file mode 100644 index 000000000..cd2e6489c --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_2d_example.mlir @@ -0,0 +1,68 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr, + %arg3 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> + %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> + // offset = [%arg3,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}: tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %6 = arith.constant 5 : i32 + %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> + // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %10 = tt.load %9 : tensor<4x256x!tt.ptr> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> + %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %13 = tt.load %12 : tensor<4x256x!tt.ptr> + %14 = arith.addf %10, %13 : tensor<4x256xbf16> + %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> + %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + tt.store %16, %14 : tensor<4x256x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : 
i32}, %[[VAL_2:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_3:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, 5] : memref to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, 5] : memref to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<4x256xbf16> +// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_15]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { +// CHECK: ^bb0(%[[VAL_17:.*]]: bf16, %[[VAL_18:.*]]: bf16, %[[VAL_19:.*]]: bf16): +// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16 +// CHECK: linalg.yield %[[VAL_20]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, 5] : memref to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_add_value.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_add_value.mlir new file mode 100644 index 000000000..5d4ab165e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_add_value.mlir @@ -0,0 +1,64 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32, + %arg3 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32> + %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> + // offset = [%arg2,0], size = [4,256], stride = [1,0] + %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> + %offset3 = arith.addi %offset2, %arg3splat : tensor<4x256xi32> + // offset = [%arg2+%arg3,0], size = [4,256], stride = [1,0] + %c10 = arith.constant 10 : i32 + %c10splat = tt.splat %c10 : i32 -> tensor<4x256xi32> + %offset4 = arith.addi %offset3, %c10splat : tensor<4x256xi32> + // offset = [%arg2+%arg3+10,0], size = [4,256], stride = [1,0] + %3 = 
tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,6] + %7 = arith.addi %offset4, %scale5: tensor<4x256xi32> + // offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>,tensor<4x256xi32> + // source = %arg0, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] + %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> + %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source = %arg1, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] + %12 = tt.load %9 : tensor<4x256x!tt.ptr> + tt.store %11, %12 : tensor<4x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : index +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, 6] : memref to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, 6] : memref to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_19]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_19]] restrict writable : memref<4x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_20]] in writable %[[VAL_18]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_bitcast.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_bitcast.mlir new file mode 100644 index 000000000..84a3cda31 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_bitcast.mlir @@ -0,0 +1,122 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s --split-input-file | FileCheck %s +module { + // CHECK-LABEL: func.func @addptr_bitcast + tt.func public @addptr_bitcast(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %c64_i64 = arith.constant 64 : i64 + %c1_i64 = arith.constant 1 : i64 + %cst_2 = arith.constant dense<0> : 
tensor<64xi8> + %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<64x!tt.ptr>, tensor<64xi32> + %81 = tt.bitcast %8 : tensor<64x!tt.ptr> -> tensor<64x!tt.ptr> + // CHECK: %[[SRC:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [64], strides: [1] : memref to memref<64xi8, strided<[1]>> + // CHECK: %[[DST:.*]] = memref.alloc() : memref<64xi8> + // CHECK: memref.copy %[[SRC]], %[[DST]] : memref<64xi8, strided<[1]>> to memref<64xi8> + %10 = tt.load %81 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x!tt.ptr> + %11 = arith.cmpi ne, %10, %cst_2 : tensor<64xi8> + %12 = arith.uitofp %11 : tensor<64xi1> to tensor<64xf32> + %16 = tt.make_tensor_ptr %arg1, [%c64_i64], [%c1_i64], [%c0_i32] {order = array} : > + tt.store %16, %12 : !tt.ptr> + tt.return + } +} + +// ----- + +module { + // CHECK-LABEL: func.func @addptr_bitcast2 + tt.func public @addptr_bitcast2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.load %6 : tensor<1024x!tt.ptr> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + %11 = arith.cmpf olt, %7, %10 : tensor<1024xf16> + %12 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.bitcast %13 : tensor<1024x!tt.ptr> -> tensor<1024x!tt.ptr> + // CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [[[VAR_1:%.+]]], sizes: [1024], strides: [1] : memref to memref<1024xi8, strided<[1], offset: ?>> + %15 = arith.extui %11 : tensor<1024xi1> to tensor<1024xi8> + tt.store %14, %15 : tensor<1024x!tt.ptr> + tt.return + } +} + +// ----- + +module { + // CHECK-LABEL: func.func @addptr_bitcast3 + // CHECK-ORIG-SAME: %[[ARG0_I8:.*]]: memref<*xi8> + tt.func public @addptr_bitcast3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<1xi8> + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %7 = tt.load %6 : tensor<1024x!tt.ptr> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %9 : tensor<1024x!tt.ptr> + %11 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %12 = tt.bitcast %11 : !tt.ptr -> !tt.ptr + // CHECK: %[[REINTERPRET_CAST_2:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [%[[PID_0:.*]]], sizes: [1], strides: [1] : memref to memref<1xi8, 
strided<[1], offset: ?>> + %13 = tt.splat %12 : !tt.ptr -> tensor<1x!tt.ptr> + %14 = tt.load %13 : tensor<1x!tt.ptr> + %15 = arith.cmpi ne, %14, %cst : tensor<1xi8> + %16 = tt.broadcast %15 : tensor<1xi1> -> tensor<1024xi1> + %17 = arith.select %16, %7, %10 : tensor<1024xi1>, tensor<1024xf16> + tt.store %6, %17 : tensor<1024x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-LABEL: func @addptr_bitcast_bool_1_tensor +tt.func public @addptr_bitcast_bool_1_tensor(%arg0 : !tt.ptr, %b : i32) -> () { + %16 = tt.splat %b : i32 -> tensor<1xi32> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<1x!tt.ptr> + //CHECK: %[[VAR_0:.*]] = arith.index_cast [[ARG_B:%.+]] : i32 to index + // CHECK: %[[RECAST:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [%0], sizes: [1], strides: [1] : memref to memref<1xi8, strided<[1], offset: ?>> + %18 = tt.addptr %17, %16 : tensor<1x!tt.ptr>, tensor<1xi32> + %19 = tt.bitcast %18 : tensor<1x!tt.ptr> -> tensor<1x!tt.ptr> + // CHECK: memref.copy %[[RECAST]], %[[ALLOC:.*]] + %20 = tt.load %19 evictionPolicy = evict_last : tensor<1x!tt.ptr> + tt.store %19, %20 : tensor<1x!tt.ptr> + tt.return +} + +// ----- + +// CHECK-LABEL: func @addptr_bitcast_bool_1_scalar +tt.func public @addptr_bitcast_bool_1_scalar(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[VAL_c0:.*]] = arith.constant 0 : index + // CHECK: %[[RECAST0:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi8, strided<[1]>> + %0 = tt.bitcast %arg0 : !tt.ptr -> !tt.ptr + // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1xi8> + // CHECK: memref.copy %[[RECAST0]], %[[ALLOC]] : memref<1xi8, strided<[1]>> to memref<1xi8> + // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<1xi8> + // CHECK: %[[VAL_0:.*]] = tensor.extract %[[TENSOR]]{{\[}}%[[VAL_c0]]] : tensor<1xi8> + %1 = tt.load %0 : !tt.ptr + %2 = tt.addptr %arg1, %c0_i32 : !tt.ptr, i32 + %3 = tt.bitcast %2 : !tt.ptr -> !tt.ptr + // CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1xi8> + // CHECK: %[[VAL_2:.*]] = linalg.fill ins(%[[VAL_0]] : i8) outs(%[[VAL_1]] : tensor<1xi8>) -> tensor<1xi8> + // CHECK: %reinterpret_cast_0 = memref.reinterpret_cast [[ARG_1:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi8, strided<[1]>> + // CHECK: bufferization.materialize_in_destination %[[VAL_2]] in writable %reinterpret_cast_0 : (tensor<1xi8>, memref<1xi8, strided<[1]>>) -> () + tt.store %3, %1 : !tt.ptr + tt.return +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_dim1.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_dim1.mlir new file mode 100644 index 000000000..300ae985f --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_dim1.mlir @@ -0,0 +1,114 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// XFAIL: * +// This test crashes because tt.broadcast's folder tries to cast +// the src operand to a RankedTensorType value, but the TritonToLinalg +// pass has already replaced the src with a value of a different type. +// We're going to retire the monolith triton-to-linalg pass which prevents +// this problem. xfailing the test for now. 
code of line28 error + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + + %arg1 : i32 + ) + { + %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + + %splat_arg0 = tt.splat %arg0 : !tt.ptr -> tensor<1x256x!tt.ptr> + %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + + // 1x256 pointer should have meaningful stride in outer dimension + %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<1x256x!tt.ptr> + + %4 = tt.splat %arg1 : i32 -> tensor<1x256xi32> + // 1x256 pointer should have meaningful stride in outer dimension + %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + tt.store %5, %3 : tensor<1x256x!tt.ptr>, tensor<1x256x!tt.ptr> + + %10 = arith.constant 0.0 : bf16 + %11 = tt.splat %10 : bf16 -> tensor<4x256xbf16> + + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %c256 = arith.constant 256 : i32 + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { + %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> + + %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + %i_i32 = arith.index_cast %i : index to i32 + %21 = arith.muli %c256, %i_i32 : i32 + %22 = tt.splat %21 : i32 -> tensor<4xi32> + %23 = arith.muli %20, %22 : tensor<4xi32> + %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %25 = tt.broadcast %24 : tensor<4x1xi32> -> tensor<4x256xi32> + + // %bptr should have zero stride and %30 should have correct stride + %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> + %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16> + + %40 = tt.splat %c256 : i32 -> tensor<1x256xi32> + %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + + scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr> + } + + %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + %splat_c256 = tt.splat %c256 : i32 -> tensor<4xi32> + %32 = arith.muli %31, %splat_c256 : tensor<4xi32> + %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %34 = tt.broadcast %33 : tensor<4x1xi32> -> tensor<4x256xi32> + %35 = tt.broadcast %2 : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> + %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + tt.store %36, %sum_out : tensor<4x256x!tt.ptr>, tensor<4x256x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4x256xbf16> +// CHECK-NOT: separator 
of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<4x256xbf16>) -> tensor<4x256xbf16> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1x256xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<1x256xbf16, strided<[256, 1]>> to memref<1x256xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index +// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_0_]] +// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[CST_0_]], [[VAR_arg8_:%.+]] = [[CST_0_]]) -> (tensor<4x256xbf16>, index, index) { +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg5_]] : index to i32 +// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_arg8_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], [[CST_1_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> +// CHECK: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16> +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg6_]], [[VAR_9_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg6_]] : tensor<4x256xbf16>) { +// CHECK: ^bb0([[in1:%.+]]: bf16, [[in2:%.+]]: bf16, [[out:%.+]]: bf16): +// CHECK: [[VAR_13_:%.+]] = arith.addf [[in1]], [[in2]] : bf16 +// CHECK: linalg.yield [[VAR_13_]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_256_]] : index +// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_11_]], [[VAR_arg8_]] : index +// CHECK: scf.yield [[VAR_10_]], [[VAR_12_]], [[CST_0_]] : tensor<4x256xbf16>, index, index +// CHECK: } +// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>> +// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir new file mode 100644 index 000000000..c7c548a30 --- /dev/null +++ 
b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir @@ -0,0 +1,92 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr, + %arg3 : i32, + %arg4 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> + %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> + // offset = [%arg3,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c5 = arith.constant 5 : i32 + %splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32> + // scalar = 5 + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here? + // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here? + // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> // Why is the input unknown + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %19 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> // this will be replaced with a memref.copy + %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> + %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr>) { + %20 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> + %sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16> + // pointer updates + %17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32> + // offset: [3, 0], size = [4, 256], stride [0, 0] + %ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5] + scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr> + } + %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> + %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] + tt.store %16, %sum_out : tensor<4x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref, %[[VAL_2:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, 
%[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32, %[[ARG_12:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [4, 256], strides: [1, 5] : memref to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<4x256xbf16> +// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_13]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref to memref<4x256xbf16, strided<[?, ?], offset: ?>> +// CHECK: %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_13]], %[[VAL_24:.*]] = %[[VAL_12]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index, index) { +// CHECK: %[[VAL_25:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_22]], %[[VAL_25]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_25]] restrict writable : memref<4x256xbf16> +// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_21]], %[[VAL_26]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_21]] : tensor<4x256xbf16>) { +// CHECK: ^bb0(%[[VAL_28:.*]]: bf16, %[[VAL_29:.*]]: bf16, %[[VAL_30:.*]]: bf16): +// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : bf16 +// CHECK: linalg.yield %[[VAL_31]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index +// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_32]], %[[VAL_24]] : index +// CHECK: %[[VAL_34:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_33]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref to memref<4x256xbf16, strided<[?, ?], offset: ?>> +// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]], %[[VAL_12]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index, index +// CHECK: } +// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_13]]], sizes: [4, 256], strides: [1, 5] : memref to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in writable %[[VAL_37]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir new file mode 100644 index 000000000..b47b6bc9e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir @@ -0,0 +1,73 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s 
| FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + + // gep operand is another gep' output, which is passed into the loop as varible, used after update + %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + + %8 = tt.broadcast %7 : tensor<256x1xi32> -> tensor<256x256xi32> + // sizes: [256, 256], offsets: [0, 0], strides: [1, 0] + + %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32> + %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + + %11 = tt.broadcast %10 : tensor<1x256xi32> -> tensor<256x256xi32> + // sizes: [256, 256], offsets: [0, 256], strides: [0, 1] + + %12 = arith.addi %8, %11 : tensor<256x256xi32> + // sizes: [256, 256], offsets: [0, 256], strides: [1, 1] + + %13 = tt.expand_dims %ptr {axis = 1 : i32} : tensor<256x!tt.ptr> -> tensor<256x1x!tt.ptr> + %14 = tt.broadcast %13 : tensor<256x1x!tt.ptr> -> tensor<256x256x!tt.ptr> + + %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + // source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [2, 1] + + // perform load + %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256x!tt.ptr> + tt.store %15, %16 : tensor<256x256x!tt.ptr> + // pointer updates + %17 = tt.splat %i_c3 : i32 -> tensor<256xi32> + // sizes: 256, offsets: 3, strides: 0 + %ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024 + i, strides: 4 + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 12 : index +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (index) { +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index +// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [256, 256], strides: {{\[}}%[[VAL_4]], 1] : memref to memref<256x256xbf16, strided<[?, 1], offset: ?>> +// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<256x256xbf16> +// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16> +// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<256x256xbf16> +// 
CHECK: bufferization.materialize_in_destination %[[VAL_16]] in writable %[[VAL_14]] +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : index +// CHECK: scf.yield %[[VAL_17]] : index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir new file mode 100644 index 000000000..859d03f2e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir @@ -0,0 +1,71 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c12 = arith.constant 12 : index + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + %3 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> + %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg1, sizes: 256, offsets: 1024, strides: 1 + %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index) { + // perform load + %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + tt.store %ptr_st, %5 : tensor<256x!tt.ptr> + // pointer updates + %cast3 = arith.index_cast %c3 : index to i32 + %6 = tt.splat %cast3 : i32 -> tensor<256xi32> + %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 1 + %arg2_iter = arith.addi %arg2, %c3 : index + %arg3_iter = arith.addi %arg3, %c3 : index + %arg4_iter = arith.addi %arg4, %c3 : index + %7 = arith.addi %arg2_iter, %arg3_iter : index + %8 = arith.addi %7, %arg4_iter : index + %cast8 = arith.index_cast %8 : index to i32 + %9 = tt.splat %cast8 : i32 -> tensor<256xi32> + %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 1 + scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index + } + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index +// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[VAL_13:.*]] = 
memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[VAL_14:.*]]:7 = scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_9]], %[[VAL_19:.*]] = %[[VAL_13]], %[[VAL_20:.*]] = %[[VAL_10]], %[[VAL_21:.*]] = %[[VAL_6]], %[[VAL_22:.*]] = %[[VAL_6]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) { +// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16> +// CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> +// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_24]] in writable %[[VAL_19]] +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index +// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index +// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_10]] : index +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_20]], %[[VAL_10]] : index +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]] : index +// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_22]], %[[VAL_31]] : index +// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: scf.yield %[[VAL_27]], %[[VAL_26]], %[[VAL_28]], %[[VAL_33]], %[[VAL_29]], %[[VAL_25]], %[[VAL_32]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir new file mode 100644 index 000000000..ef448a845 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir @@ -0,0 +1,98 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + // gep operand is another gep' output, which is passed into the loop as varible, used after update + %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + // pointer updates + %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> + // sizes: 256, offsets: 3, strides: 0 + %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024 + i, strides: 1 + // perform load + %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, 
isVolatile = false}: tensor<256x!tt.ptr> + tt.store %ptr_iter, %3 : tensor<256x!tt.ptr> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // %subview = memref.subview %arg0, [%4][256][4] : memref<> -> memref<> <- generate subview on getelementptr (already done) + // ... + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + // TODO: examples below are not supported since scf.for does not support returning a tensor type + // Example 3, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used after update + //%_ptr3 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { + // // offset update + // %3 = tt.splat %c3 : i32 -> tensor<256xi32> + // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> + // // generate pointer + // %gep_ptr = tt.addptr %0, %ptr_iter : tensor<256x!tt.ptr> + // // perform load + // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + // tt.store %gep_ptr, %4 : tensor<256x!tt.ptr> + // scf.yield %ptr_iter : tensor<256xi32> + //} + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) + // ... + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + //// Example 4, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used before update + //%_ptr4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { + // // generate pointer + // %gep_ptr = tt.addptr %0, %ptr : tensor<256x!tt.ptr> + // + // // perform load + // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + // tt.store %gep_ptr, %4 : tensor<256x!tt.ptr> + // // offset update + // %3 = tt.splat %c3 : i32 -> tensor<256xi32> + // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> + // scf.yield %ptr_iter : tensor<256xi32> + //} + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) + // ... 
+ // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_5]]) -> (index) { +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_13]] +// CHECK: scf.yield %[[VAL_12]] : index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir new file mode 100644 index 000000000..768570fa9 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir @@ -0,0 +1,55 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr + ) + { + %c0 = arith.constant 0 : index + %c12 = arith.constant 12 : index + %c3 = arith.constant 3 : index + %i_c3 = arith.constant 3 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> + // source: null, sizes: 256, offsets: 1024, strides: 1 + %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> + // source: arg0, sizes: 256, offsets: 1024, strides: 1 + // Example 2, gep operand is another gep's output, which is passed into the loop as varible, used before update + %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { + // perform load + %3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> + tt.store %ptr, %3 : tensor<256x!tt.ptr> + // pointer updates + %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> + %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %ptr_iter : tensor<256x!tt.ptr> + } + // Expected output + // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) + // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) + // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) + // ... 
+ // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) + // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) + // } + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_5]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]], %[[VAL_13:.*]] = %[[VAL_5]]) -> (memref<256xbf16, strided<[?], offset: ?>>, index) { +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> +// CHECK: memref.copy %[[VAL_12]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_12]] +// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index +// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_16]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref to memref<256xbf16, strided<[?], offset: ?>> +// CHECK: scf.yield %[[VAL_17]], %[[VAL_16]] : memref<256xbf16, strided<[?], offset: ?>>, index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_loopback.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_loopback.mlir new file mode 100644 index 000000000..c6bd315e4 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_loopback.mlir @@ -0,0 +1,52 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32> + %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> + // offset = [%arg2,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,6] + %7 = arith.addi %offset2, 
%scale5: tensor<4x256xi32> + // offset = [%arg2, 0], size = [4, 256], stride = [1, 6] + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: arg0, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] + %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> + %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: arg1, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] + %12 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> + tt.store %11, %12 : tensor<4x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_7]]], sizes: [4, 256], strides: [1, 6] : memref to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_7]]], sizes: [4, 256], strides: [1, 6] : memref to memref<4x256xbf16, strided<[1, 6], offset: ?>> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<4x256xbf16, strided<[1, 6], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<4x256xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_10]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir new file mode 100644 index 000000000..6e048c3ab --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir @@ -0,0 +1,49 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> + %2 = tt.splat %0 : i32 -> tensor<1024xi32> + %3 = arith.addi %2, %1 : tensor<1024xi32> + //%3: splat(%0) + range(0, 1024) + //%3: offset = %0, size = 1024, stride = 1 + // vector and scalar are both constant + %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> + %c10 = arith.constant 10 : i32 + %5 = tt.splat %c10 : i32 -> tensor<1024xi32> + %6 = arith.muli %5, %4 : tensor<1024xi32> + //%6: splat(%c10)*range(2048, 4096); + //%6: offset = %c10*2048, size = 1024, stride = %c10*1 + %7 = arith.addi %3, %6 : tensor<1024xi32> + //%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*1+1 + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*1+1 + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg1, offset = pid0, size = 1024, stride = 1 + %16 = tt.load %9 : tensor<1024x!tt.ptr> + tt.store %11, %16 : 
tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:[A-Za-z0-9_]+]]: memref, %[[ARG_1:[A-Za-z0-9_]+]]: memref, %[[VAL_0:[A-Za-z0-9_]+]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:[A-Za-z0-9_]+]]: memref {tt.tensor_kind = 1 : i32} +// CHECK-SAME: %[[ARG_4:[A-Za-z0-9_]+]]: i32, %[[ARG_5:[A-Za-z0-9_]+]]: i32, %[[ARG_6:[A-Za-z0-9_]+]]: i32, %[[ARG_7:[A-Za-z0-9_]+]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:[A-Za-z0-9_]+]]: i32, %[[ARG_10:[A-Za-z0-9_]+]]: i32 +// CHECK-SAME: ) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:[A-Za-z0-9_]+]] = arith.constant 20480 : index +// CHECK: %[[VAL_8:[A-Za-z0-9_]+]] = arith.index_cast %[[ARG_8]] : i32 to index +// CHECK: %[[VAL_9:[A-Za-z0-9_]+]] = arith.addi %[[VAL_8]], %[[VAL_7]] : index +// CHECK: %[[VAL_10:[A-Za-z0-9_]+]] = memref.reinterpret_cast %[[VAL_0]] to offset: [%[[VAL_9]]], sizes: [1024], strides: [11] : memref to memref<1024xbf16, strided<[11], offset: ?>> +// CHECK: %[[VAL_12:[A-Za-z0-9_]+]] = memref.reinterpret_cast %[[VAL_1]] to offset: [%[[VAL_8]]], sizes: [1024], strides: [1] : memref to memref<1024xbf16, strided<[1], offset: ?>> +// CHECK: %[[VAL_13:[A-Za-z0-9_]+]] = memref.alloc() : memref<1024xbf16> +// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[11], offset: ?>> to memref<1024xbf16> +// CHECK: %[[VAL_14:[A-Za-z0-9_]+]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_14]] in writable %[[VAL_12]] : (tensor<1024xbf16>, memref<1024xbf16, strided<[1], offset: ?>>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir new file mode 100644 index 000000000..64ed953ab --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir @@ -0,0 +1,52 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.get_program_id x : i32 + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> + %2 = tt.splat %0 : i32 -> tensor<1024xi32> + %3 = arith.addi %2, %1 : tensor<1024xi32> + //%3: splat(%0) + range(0, 1024) + //%3: offset = %0, size = 1024, stride = 1 + // vector is constant, scalar is value + %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> + %6 = arith.muli %5, %4 : tensor<1024xi32> + //%6: splat(%arg2)*range(2048, 3072); + //%6: offset = %arg2*2048, size = 1024, stride = %arg2*1 + %7 = arith.addi %3, %6 : tensor<1024xi32> + //%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*1+1 + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*1+1 + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + //source=%arg1: offset = pid0, size = 1024, stride = 1 + %16 = tt.load %9 : tensor<1024x!tt.ptr> + tt.store %11, %16 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:[A-Za-z0-9_]+]]: memref, %[[ARG_1:[A-Za-z0-9_]+]]: memref, %[[VAL_0:[A-Za-z0-9_]+]]: memref {tt.tensor_kind 
= 0 : i32}, %[[VAL_1:[A-Za-z0-9_]+]]: memref {tt.tensor_kind = 1 : i32} +// CHECK-SAME: %[[ARG_4:[A-Za-z0-9_]+]]: i32, %[[ARG_5:[A-Za-z0-9_]+]]: i32, %[[ARG_6:[A-Za-z0-9_]+]]: i32, %[[ARG_7:[A-Za-z0-9_]+]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:[A-Za-z0-9_]+]]: i32, %[[ARG_10:[A-Za-z0-9_]+]]: i32 +// CHECK-SAME: ) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2048 : index +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[ARG_8]] : i32 to index +// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[ARG_4]] : i32 to index +// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_10]], %[[VAL_6]] : index +// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [1024], strides: {{\[}}%[[VAL_14]]] : memref to memref<1024xbf16, strided<[?], offset: ?>> +// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: [1] : memref to memref<1024xbf16, strided<[1], offset: ?>> +// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xbf16> +// CHECK: memref.copy %[[VAL_15]], %[[VAL_18]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> +// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_17]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_nested.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_nested.mlir new file mode 100644 index 000000000..5352f4728 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_nested.mlir @@ -0,0 +1,65 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : i32 + ) + { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> + // offset = 0, size = 4, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + // offset = [0,0], size = [4,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [1,0] + %arg1splat = tt.splat %arg1 : i32 -> tensor<4x256xi32> + %offset3 = arith.addi %2, %arg1splat : tensor<4x256xi32> + // offset = [%arg1,0], size = [4,256], stride = [1,0] + %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> + // offset = 0, size = 256, stride = 1 + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + // offset = [0,0], size = [1,256], stride = [0,1] + %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,1] + %6 = arith.constant 5 : i32 + %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32> + %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> + // offset = [0,0], size = [4,256], stride = [0,5] + %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> + // offset = [%arg1, 0], size = [4, 256], stride = [1, 5] + %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg1, 0], size = [4, 256], stride = [1, 5] + %10 = tt.load %9 {cache = 1 : i32, evict = 1 : 
i32, isVolatile = false}: tensor<4x256x!tt.ptr> + %12 = tt.addptr %9, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg1+%arg1, 0], size = [4, 256], stride = [2, 10] + %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> + %14 = arith.addf %10, %13 : tensor<4x256xbf16> + %16 = tt.addptr %12, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> + // source: %arg0, offset = [%arg1+%arg1+%arg1, 0], size = [4, 256], stride = [3, 15] + tt.store %16, %14 : tensor<4x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 2 : i32}, +// CHECK-SAME: %[[VAL_1:.*]]: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_1]] : i32 to index +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, 5] : memref to memref<4x256xbf16, strided<[1, 5], offset: ?>> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, 5], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_8]], %[[VAL_8]] : index +// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_14]]], sizes: [4, 256], strides: [2, 10] : memref to memref<4x256xbf16, strided<[2, 10], offset: ?>> +// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<4x256xbf16> +// CHECK: memref.copy %[[VAL_15]], %[[VAL_16]] : memref<4x256xbf16, strided<[2, 10], offset: ?>> to memref<4x256xbf16> +// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<4x256xbf16> +// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_17]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { +// CHECK: ^bb0(%[[VAL_19:.*]]: bf16, %[[VAL_20:.*]]: bf16, %[[VAL_21:.*]]: bf16): +// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : bf16 +// CHECK: linalg.yield %[[VAL_22]] : bf16 +// CHECK: } -> tensor<4x256xbf16> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_14]], %[[VAL_8]] : index +// CHECK: %[[VAL_28:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [4, 256], strides: [3, 15] : memref to memref<4x256xbf16, strided<[3, 15], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_29:.*]] in writable %[[VAL_28]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir new file mode 100644 index 000000000..d4e59257c --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir @@ -0,0 +1,43 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +// TODO: expand this example to 3D +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> + // offset = [512] size = 256, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : 
i32} : tensor<256xi32> -> tensor<256x1xi32> + // offset = [512,0], size = [256,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> + // offset = [512,0], size = [256,128], stride = [1,0] + %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> + // offset = 1024, size = 128, stride = 1 + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + // offset = [0,1024], size = [1,128], stride = [0,1] + %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> + // offset = [0,1024], size = [256,128], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> + %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> + // offset = [0,6144], size = [256,128], stride = [0,6] + %14 = arith.addi %2, %scale7 : tensor<256x128xi32> + // offset = [512,6144], size = [256,128], stride = [1,6] + %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> + %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + %19 = tt.load %18 : tensor<256x128x!tt.ptr> + tt.store %18, %19 : tensor<256x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref {tt.tensor_kind = 2 : i32}, +// CHECK-SAME: %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: {{\[}}1, 6] : memref to memref<256x128xbf16, strided<[1, 6], offset: 6656>> +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<256x128xbf16> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16> +// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[VAL_7]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir new file mode 100644 index 000000000..84bd0d522 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir @@ -0,0 +1,66 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = arg1, offset = %1, size = 1, strides = 0 + %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> + // source = arg1, offset = %1, size = 1024, strides = 0 + %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<1024x!tt.ptr> -> tensor<1024x1x!tt.ptr> + // source = arg1, offset = [%1, 0], size = [1024, 1], strides = [0, 0] + %5 = tt.broadcast %4 : tensor<1024x1x!tt.ptr> -> tensor<1024x1024x!tt.ptr> + // source = arg1, offset = [%1, 0], size = [1024, 1024], strides = [0, 0] + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32> -> tensor<1x1024xi32> + // offset = [0, 0], size = [1, 1024], strides = [0, 1] + %8 = tt.broadcast %7 
: tensor<1x1024xi32> -> tensor<1024x1024xi32> + // offset = [0, 0], size = [1024, 1024], strides = [0, 1] + %9 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<1024xi32> -> tensor<1024x1xi32> + // offset = [0, 0], size = [1024, 1], strides = [1, 0] + %11 = tt.broadcast %10 : tensor<1024x1xi32> -> tensor<1024x1024xi32> + // offset = [0, 0], size = [1024, 1024], strides = [1, 0] + %12 = arith.addi %8, %11 : tensor<1024x1024xi32> + // offset = [0, 0], size = [1024, 1024], strides = [1, 1] + %13 = tt.addptr %5, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> + // source = arg1, offset = [pid * %arg2, 0], size = [1024, 1024], strides = [1, 1] + %14 = tt.load %13 : tensor<1024x1024x!tt.ptr> + %17 = math.exp %14 : tensor<1024x1024xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = pid+arg3, size = 1, strides = 0 + %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> + // source = arg0, offset = pid+arg3, size = 1024, strides = 0 + %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<1024x!tt.ptr> -> tensor<1024x1x!tt.ptr> + // source = arg0, offset = [pid+arg3, 0], size = [1024, 1], strides = [0, 0] + %22 = tt.broadcast %21 : tensor<1024x1x!tt.ptr> -> tensor<1024x1024x!tt.ptr> + // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [0, 0] + %23 = tt.addptr %22, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> + // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [1, 1] + tt.store %23, %17 : tensor<1024x1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %[[VAL_1:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32, %[[ARG_12:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_10]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index +// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024, 1024], strides: [1, 1] : memref to memref<1024x1024xf32, strided<[1, 1], offset: ?>> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024x1024xf32> +// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024x1024xf32, strided<[1, 1], offset: ?>> to memref<1024x1024xf32> +// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024x1024xf32> +// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_12]] : tensor<1024x1024xf32>) outs(%[[VAL_12]] : tensor<1024x1024xf32>) { +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 +// CHECK: linalg.yield %[[VAL_16]] : f32 +// CHECK: } -> tensor<1024x1024xf32> +// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_10]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index +// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [1, 1] : memref to 
memref<1024x1024xf32, strided<[1, 1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in writable %[[VAL_19]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir new file mode 100644 index 000000000..0734f88bc --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir @@ -0,0 +1,32 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) { + %0 = tt.addptr %arg0, %arg2 : !tt.ptr, i32 + %1 = tt.addptr %arg1, %arg2 : !tt.ptr, i32 + %10 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: !tt.ptr + tt.store %1, %10 : !tt.ptr + tt.return + } +} + +// CHECK: module { +// CHECK: func.func @kernel(%arg0: memref, %arg1: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %[[VAL_2:.*]]: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_5:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_4]]], sizes: [1], strides: [1] : memref to memref<1xbf16, strided<[1], offset: ?>> +// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<1xbf16> +// CHECK: memref.copy %[[VAL_5]], %[[VAL_6]] : memref<1xbf16, strided<[1], offset: ?>> to memref<1xbf16> +// CHECK: %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_6]] restrict writable : memref<1xbf16> +// CHECK: %[[VAL_8:.*]] = tensor.extract %[[VAL_7]]{{\[}}%[[VAL_3]]] : tensor<1xbf16> +// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<1xbf16> +// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : bf16) outs(%[[VAL_9]] : tensor<1xbf16>) -> tensor<1xbf16> +// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [%[[VAL_4]]], sizes: [1], strides: [1] : memref to memref<1xbf16, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_10]] in writable %[[VAL_11]] : (tensor<1xbf16>, memref<1xbf16, strided<[1], offset: ?>>) -> () +// CHECK: return +// CHECK: } +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir new file mode 100644 index 000000000..abcd66e16 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir @@ -0,0 +1,56 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = arg1, offset = %1, size = 1, strides = 0 + %3 = arith.muli %0, %arg3 : i32 + %4 = tt.addptr %2, %3 : !tt.ptr, i32 + // source = arg1, offset = %1+%3, size = 1, strides = 0 + %5 = arith.muli %0, %arg4 : i32 + %6 = tt.addptr %4, %5 : !tt.ptr, i32 + // source = arg1, offset = %1+%3+%5, size = 1, strides = 0 + %7 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %8 = 
tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> + // source = arg1, offset = %1, size = 1024, strides = 0 + %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = arg1, offset = %1+%3+%5, size = 1024, strides = 1 + %10 = tt.load %9 : tensor<1024x!tt.ptr> + %17 = math.exp %10 : tensor<1024xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> + // source = arg0, offset = %18, size = 1024, strides = 0 + %21 = tt.addptr %20, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = arg0, offset = %18, size = 1024, strides = 1 + tt.store %21, %17 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %[[VAL_1:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32, %[[ARG_12:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_10]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index +// CHECK: %[[VAL_10:.*]] = arith.muli %[[ARG_10]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : i32 to index +// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index +// CHECK: %[[VAL_13:.*]] = arith.muli %[[ARG_10]], %[[VAL_4]] : i32 +// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i32 to index +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_14]] : index +// CHECK: %[[VAL_16:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_15]]], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_16]], %[[VAL_17]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> +// CHECK: %[[VAL_18:.*]] = bufferization.to_tensor %[[VAL_17]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18]] : tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): +// CHECK: %[[VAL_22:.*]] = math.exp %[[VAL_20]] : f32 +// CHECK: linalg.yield %[[VAL_22]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_25:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_11]]], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_19:.*]] in writable %[[VAL_25]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir new file mode 100644 index 000000000..16dd6c554 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir @@ -0,0 +1,46 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = 
tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + // source = %arg1, offset = %1, size = 1, strides = 0 + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // offset = 0, size = 1024, strides = 1 + %4 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> + // source = %arg1, offset = %1, size = 1024, strides = 0 + %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = %arg1, offset = %1, size = 1024, strides = 1 + %8 = tt.load %5 : tensor<1024x!tt.ptr> + %17 = math.exp %8 : tensor<1024xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = %arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> + // source = %arg0, offset = %18, size = 1024, strides = 0 + %21 = tt.addptr %20, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // source = %arg0, offset = %18, size = 1024, strides = 1 + tt.store %21, %17 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %[[VAL_1:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32, %[[ARG_12:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_10]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index +// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> +// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_12]] : tensor<1024xf32>) outs(%[[VAL_12]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 +// CHECK: linalg.yield %[[VAL_16]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_10]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index +// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in writable %[[VAL_19]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir new file mode 100644 index 000000000..f6f7e3cfa --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir @@ -0,0 +1,57 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { + %0 = 
tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<128x128x!tt.ptr> + // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] + %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> + // offset = [0, 0], size = [128, 128], strides = [0, 1] + %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> + // offset = 128, size = 128, strides = 1 + %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 0] + %10 = arith.addi %6, %9 : tensor<128x128xi32> + // offset = [128, 0], size = [128, 128], strides = [1, 1] + %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] + %12 = tt.load %11 : tensor<128x128x!tt.ptr> + %17 = math.exp %12 : tensor<128x128xf32> + %18 = arith.muli %0, %arg3 : i32 + %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 + // source = arg0, offset = %18, size = 1, strides = 0 + %20 = tt.splat %19 : !tt.ptr -> tensor<128x128x!tt.ptr> + // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] + %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] + tt.store %21, %17 : tensor<128x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %[[VAL_1:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32, %[[ARG_12:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_8:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_10]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_8]] : index +// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1], offset: ?>> +// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<128x128xf32> +// CHECK: memref.copy %[[VAL_12]], %[[VAL_13]] : memref<128x128xf32, strided<[1, 1], offset: ?>> to memref<128x128xf32> +// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<128x128xf32> +// CHECK: %[[VAL_15:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_14]] : tensor<128x128xf32>) outs(%[[VAL_14]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32): +// CHECK: %[[VAL_18:.*]] = math.exp %[[VAL_16]] : f32 +// CHECK: linalg.yield %[[VAL_18]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: %[[VAL_19:.*]] = arith.muli %[[ARG_10]], %[[VAL_3]] : i32 +// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : i32 to index +// CHECK: %[[VAL_21:.*]] = arith.addi 
%[[VAL_20]], %[[VAL_8]] : index +// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_splat_addptr.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_splat_addptr.mlir new file mode 100644 index 000000000..32f660033 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_splat_addptr.mlir @@ -0,0 +1,37 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + // CHECK-LABEL: func @addptr_splat_addptr + tt.func public @addptr_splat_addptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c16_i64 = arith.constant 16 : i64 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.make_tensor_ptr %arg1, [%c4_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %1] {order = array} : !tt.ptr> + %3 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + // CHECK: %[[OFFSET1:.*]] = arith.index_cast %[[BASE1:.*]] : i32 to index + %6 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %2) -> !tt.ptr> : i32 { + %7 = tt.load %arg4 : !tt.ptr> + %8 = tt.advance %arg4, [%c0_i32, %c32_i32] : > + %9 = tt.reshape %7 : tensor<4x8xf16> -> tensor<32xf16> + // CHECK: %[[IV_OFFSET:.*]] = arith.muli %[[IV:.*]], %c32_i32 : i32 + // CHECK: %[[OFFSET2:.*]] = arith.index_cast %[[IV_OFFSET]] : i32 to index + // CHECK: %[[OFFSET:.*]] = arith.addi %[[OFFSET1]], %[[OFFSET2]] : index + // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [%[[OFFSET]]], sizes: [32], strides: [1] : memref to memref<32xf16, strided<[1], offset: ?>> + // CHECK: bufferization.materialize_in_destination %[[RESHAPE:.*]] in writable %[[CAST]] : (tensor<32xf16>, memref<32xf16, strided<[1], offset: ?>>) -> () + %10 = arith.muli %arg3, %c32_i32 : i32 + %11 = tt.addptr %5, %10 : !tt.ptr, i32 + %12 = tt.splat %11 : !tt.ptr -> tensor<32x!tt.ptr> + %13 = tt.addptr %12, %3 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %13, %9 : tensor<32x!tt.ptr> + scf.yield %8 : !tt.ptr> + } + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/addptr_with_int_to_ptr_source.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_with_int_to_ptr_source.mlir new file mode 100644 index 000000000..cf1ea35bb --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/addptr_with_int_to_ptr_source.mlir @@ -0,0 +1,41 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-unstructure --bubble-up-operation --discrete-mask-access-conversion --triton-to-hivm "--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False" %s | FileCheck %s + +module { + tt.func @addptr_with_int_to_ptr_source(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %c0_i64 = arith.constant 0 : i64 + %0 = tt.load %arg0 : !tt.ptr + %1 = tt.int_to_ptr %0 : i64 -> !tt.ptr + %2 = tt.addptr %1, %c0_i64 : !tt.ptr, i64 + %3 = tt.load %2 : !tt.ptr + 
tt.store %arg1, %3 : !tt.ptr + tt.return + } +} + +// CHECK-LABEL: func.func @addptr_with_int_to_ptr_source +// CHECK: %[[arg0:.*]]: memref {tt.tensor_kind = 0 : i32}, +// CHECK: %[[arg1:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1xi64> +// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_2:.*]][0] [1] [1] : memref to memref<1xi64, strided<[1]>> +// CHECK: memref.copy %[[VAL_12]], %[[VAL_11]] : memref<1xi64, strided<[1]>> to memref<1xi64> +// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1xi64> +// CHECK: %[[VAL_14:.*]] = tensor.extract %[[VAL_13]]{{\[}}%[[VAL_10]]] : tensor<1xi64> +// CHECK: %[[VAL_15:.*]] = arith.constant 8 : i64 +// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_17:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_17]], %[[VAL_15]] : i64 +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_14]], %[[VAL_18]] : i64 +// CHECK: %[[VAL_20:.*]] = hivm.hir.pointer_cast(%[[VAL_19]]) {{\[}}%[[VAL_16]]] : memref +// CHECK: annotation.mark %[[VAL_20]] {address_space = #hivm.address_space} : memref +// CHECK: %[[VAL_21:.*]] = memref.reinterpret_cast %[[VAL_20]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi64, strided<[1]>> +// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<1xi64> +// CHECK: memref.copy %[[VAL_21]], %[[VAL_22]] : memref<1xi64, strided<[1]>> to memref<1xi64> +// CHECK: %[[VAL_23:.*]] = bufferization.to_tensor %[[VAL_22]] restrict writable : memref<1xi64> +// CHECK: %[[VAL_24:.*]] = tensor.extract %[[VAL_23]]{{\[}}%[[VAL_10]]] : tensor<1xi64> +// CHECK: %[[VAL_25:.*]] = tensor.empty() : tensor<1xi64> +// CHECK: %[[VAL_26:.*]] = linalg.fill ins(%[[VAL_24]] : i64) outs(%[[VAL_25]] : tensor<1xi64>) -> tensor<1xi64> +// CHECK: %[[VAL_27:.*]] = memref.reinterpret_cast %[[VAL_3:.*]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi64, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_26]] in writable %[[VAL_27]] : (tensor<1xi64>, memref<1xi64, strided<[1]>>) -> () +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir new file mode 100644 index 000000000..bd468cfb3 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir @@ -0,0 +1,40 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %a : !tt.ptr, + %b : !tt.ptr + ) -> () { + // offset calculations + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // a pointer + %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // b pointer + %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %am = tt.load %9 : tensor<1024x!tt.ptr> + %bm = tt.load %19 : tensor<1024x!tt.ptr> + %5 = arith.addi %am, %bm : tensor<1024xi32> + tt.store %19, %5 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 2 : i32}, +// CHECK-SAME: %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, 
%[[ARG_9:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_5:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xi32, strided<[1]>> +// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xi32, strided<[1]>> +// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<1024xi32> +// CHECK: memref.copy %[[VAL_5]], %[[VAL_7]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> +// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_7]] restrict writable : memref<1024xi32> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<1024xi32> +// CHECK: memref.copy %[[VAL_6]], %[[VAL_9]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<1024xi32> +// CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_8]], %[[VAL_10]] : tensor<1024xi32>, tensor<1024xi32>) outs(%[[VAL_8]] : tensor<1024xi32>) { +// CHECK: ^bb0(%[[VAL_12:.*]]: i32, %[[VAL_13:.*]]: i32, %[[VAL_14:.*]]: i32): +// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32 +// CHECK: linalg.yield %[[VAL_15]] : i32 +// CHECK: } -> tensor<1024xi32> +// CHECK: bufferization.materialize_in_destination %[[VAL_16:.*]] in writable %[[VAL_6]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/barrier.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/barrier.mlir new file mode 100644 index 000000000..392979d43 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/barrier.mlir @@ -0,0 +1,13 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +tt.func @test_barrier(%arg0 : !tt.ptr, %arg1 : !tt.ptr, %arg2 : i32) { + %0 = tt.addptr %arg0, %arg2 : !tt.ptr, i32 + %1 = tt.addptr %arg1, %arg2 : !tt.ptr, i32 + %10 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: !tt.ptr + tt.store %1, %10 : !tt.ptr + gpu.barrier + tt.return +} + +//CHECK-LABEL: test_barrier +//CHECK: gpu.barrier diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/bitcast.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/bitcast.mlir new file mode 100644 index 000000000..78d95c01b --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/bitcast.mlir @@ -0,0 +1,45 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func @kernel(%a : !tt.ptr, %b : !tt.ptr) -> () { + // offset calculations + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + + // a pointer + %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + + // b pointer + %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + + %am = tt.load %9 : tensor<1024x!tt.ptr> + + // cast result before doing float add + %am_bitcast = tt.bitcast %am : tensor<1024xi32> -> tensor<1024xf32> + + + tt.store %19, %am_bitcast : tensor<1024x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: 
i32, %[[ARG_9:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: [[RC_:%.+]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref to memref<1024xi32, strided<[1]>> +// CHECK: [[RC_0_:%.+]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref to memref<1024xf32, strided<[1]>> +// CHECK: [[ALLOC_:%.+]] = memref.alloc() : memref<1024xi32> +// CHECK: memref.copy [[RC_]], [[ALLOC_]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> +// CHECK: [[VAR_0_:%.+]] = bufferization.to_tensor [[ALLOC_]] restrict writable : memref<1024xi32> +// CHECK: [[VAR_1_:%.+]] = tensor.empty() : tensor<1024xf32> +// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_0_]] : tensor<1024xi32>) outs([[VAR_1_]] : tensor<1024xf32>) { +// CHECK: ^bb0(%in: i32, %out: f32): +// CHECK: [[VAR_5_:%.+]] = arith.bitcast %in : i32 to f32 +// CHECK: linalg.yield [[VAR_5_]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[RC_0_]] +// CHECK: return +// CHECK: } + + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/block_ptr_advance.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/block_ptr_advance.mlir new file mode 100644 index 000000000..841a21b1e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/block_ptr_advance.mlir @@ -0,0 +1,129 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @matmul_kernel_with_block_pointers_01234567891011(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) { + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 0.000000e+00 : bf16 + %c256_i32 = arith.constant 256 : i32 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = arith.extsi %arg7 : i32 to i64 + %4 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %3], [%arg12, %c0_i32] {order = array} : > + %5 = tt.advance %4, [%c0_i32, %c64_i32] : > + %6 = tt.splat %cst : bf16 -> tensor<128x64xbf16> + %7:3 = scf.for %arg14 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg15 = %6, %arg16 = %5, %arg17 = %4) -> (tensor<128x64xbf16>, !tt.ptr>, !tt.ptr>) : i32 { + %13 = tt.load %arg16 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + %14 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + %15 = arith.addf %13, %14 : tensor<128x64xbf16> + %16 = arith.addf %arg15, %15 : tensor<128x64xbf16> + %17 = tt.advance %arg16, [%c0_i32, %c64_i32] : > + %18 = tt.advance %arg17, [%c64_i32, %c0_i32] : > + scf.yield %16, %17, %18 : tensor<128x64xbf16>, !tt.ptr>, !tt.ptr> + } + %8 = arith.extsi %arg10 : i32 to i64 + %9 = arith.extsi %arg11 : i32 to i64 + %10 = arith.extsi %arg4 : i32 to i64 + %11 = arith.muli %arg13, %c256_i32 : i32 + %12 = tt.make_tensor_ptr %arg2, [%0, %10], [%8, %9], [%arg12, %11] {order = array} : > + tt.store %12, %7#0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr> + tt.return + } +} + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @matmul_kernel_with_block_pointers_01234567891011( +// CHECK-SAME: 
%[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref, %[[VAL_2:.*]]: memref, %[[VAL_3:.*]]: memref, %[[VAL_4:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i32, %[[VAL_12:.*]]: i32, %[[VAL_13:.*]]: i32, %[[VAL_14:.*]]: i32, %[[VAL_15:.*]]: i32, %[[VAL_16:.*]]: i32, %[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32, %[[VAL_19:.*]]: i32, %[[VAL_20:.*]]: i32, %[[VAL_21:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_22:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_23:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_24:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_25:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_26:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_27:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_28:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<128x64xbf16> +// CHECK: %[[VAL_30:.*]] = linalg.fill ins(%[[VAL_28]] : bf16) outs(%[[VAL_29]] : tensor<128x64xbf16>) -> tensor<128x64xbf16> +// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_14]] : i32 to index +// CHECK: %[[VAL_32:.*]] = arith.index_cast %[[VAL_8]] : i32 to index +// CHECK: %[[VAL_33:.*]] = arith.index_cast %[[VAL_9]] : i32 to index +// CHECK: %[[VAL_34:.*]] = arith.muli %[[VAL_31]], %[[VAL_32]] : index +// CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_5]] : i32 to index +// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_7]] : i32 to index +// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_33]], %[[VAL_24]] : index +// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_34]], %[[VAL_37]] : index +// CHECK: %[[VAL_39:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_38]]], sizes: [128, 64], strides: {{\[}}%[[VAL_32]], %[[VAL_33]]] : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_34]]], sizes: [128, 64], strides: {{\[}}%[[VAL_32]], %[[VAL_33]]] : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %[[VAL_41:.*]]:7 = scf.for %[[VAL_42:.*]] = %[[VAL_26]] to %[[VAL_7]] step %[[VAL_27]] iter_args(%[[VAL_43:.*]] = %[[VAL_30]], %[[VAL_44:.*]] = %[[VAL_39]], %[[VAL_45:.*]] = %[[VAL_40]], %[[VAL_46:.*]] = %[[VAL_38]], %[[VAL_47:.*]] = %[[VAL_23]], %[[VAL_48:.*]] = %[[VAL_34]], %[[VAL_49:.*]] = %[[VAL_23]]) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index) : i32 { +// CHECK: %[[VAL_50:.*]] = memref.alloc() : memref<128x64xbf16> +// CHECK: %[[VAL_51:.*]] = arith.divsi %[[VAL_46]], %[[VAL_32]] : index +// CHECK: %[[VAL_52:.*]] = arith.subi %[[VAL_35]], %[[VAL_51]] : index +// CHECK: %[[VAL_53:.*]] = arith.maxsi %[[VAL_52]], %[[VAL_23]] : index +// CHECK: %[[VAL_54:.*]] = arith.minsi %[[VAL_53]], %[[VAL_22]] : index +// CHECK: %[[VAL_55:.*]] = arith.remsi %[[VAL_46]], %[[VAL_32]] : index +// CHECK: %[[VAL_56:.*]] = arith.divsi %[[VAL_55]], %[[VAL_33]] : index +// CHECK: %[[VAL_57:.*]] = arith.subi %[[VAL_36]], %[[VAL_56]] : index +// CHECK: %[[VAL_58:.*]] = arith.maxsi %[[VAL_57]], %[[VAL_23]] : index +// CHECK: %[[VAL_59:.*]] = arith.minsi %[[VAL_58]], %[[VAL_24]] : index +// CHECK: %[[VAL_60:.*]] = memref.subview %[[VAL_44]][0, 0] {{\[}}%[[VAL_54]], %[[VAL_59]]] [1, 1] : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref> +// 
CHECK: %[[VAL_61:.*]] = memref.subview %[[VAL_50]][0, 0] {{\[}}%[[VAL_54]], %[[VAL_59]]] [1, 1] : memref<128x64xbf16> to memref> +// CHECK: memref.copy %[[VAL_60]], %[[VAL_61]] : memref> to memref> +// CHECK: %[[VAL_62:.*]] = bufferization.to_tensor %[[VAL_50]] restrict writable : memref<128x64xbf16> +// CHECK: %[[VAL_63:.*]] = memref.alloc() : memref<128x64xbf16> +// CHECK: %[[VAL_64:.*]] = arith.divsi %[[VAL_48]], %[[VAL_32]] : index +// CHECK: %[[VAL_65:.*]] = arith.subi %[[VAL_35]], %[[VAL_64]] : index +// CHECK: %[[VAL_66:.*]] = arith.maxsi %[[VAL_65]], %[[VAL_23]] : index +// CHECK: %[[VAL_67:.*]] = arith.minsi %[[VAL_66]], %[[VAL_22]] : index +// CHECK: %[[VAL_68:.*]] = arith.remsi %[[VAL_48]], %[[VAL_32]] : index +// CHECK: %[[VAL_69:.*]] = arith.divsi %[[VAL_68]], %[[VAL_33]] : index +// CHECK: %[[VAL_70:.*]] = arith.subi %[[VAL_36]], %[[VAL_69]] : index +// CHECK: %[[VAL_71:.*]] = arith.maxsi %[[VAL_70]], %[[VAL_23]] : index +// CHECK: %[[VAL_72:.*]] = arith.minsi %[[VAL_71]], %[[VAL_24]] : index +// CHECK: %[[VAL_73:.*]] = memref.subview %[[VAL_45]][0, 0] {{\[}}%[[VAL_67]], %[[VAL_72]]] [1, 1] : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref> +// CHECK: %[[VAL_74:.*]] = memref.subview %[[VAL_63]][0, 0] {{\[}}%[[VAL_67]], %[[VAL_72]]] [1, 1] : memref<128x64xbf16> to memref> +// CHECK: memref.copy %[[VAL_73]], %[[VAL_74]] : memref> to memref> +// CHECK: %[[VAL_75:.*]] = bufferization.to_tensor %[[VAL_63]] restrict writable : memref<128x64xbf16> +// CHECK: %[[VAL_76:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_62]], %[[VAL_75]] : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%[[VAL_62]] : tensor<128x64xbf16>) { +// CHECK: ^bb0(%[[VAL_77:.*]]: bf16, %[[VAL_78:.*]]: bf16, %[[VAL_79:.*]]: bf16): +// CHECK: %[[VAL_80:.*]] = arith.addf %[[VAL_77]], %[[VAL_78]] : bf16 +// CHECK: linalg.yield %[[VAL_80]] : bf16 +// CHECK: } -> tensor<128x64xbf16> +// CHECK: %[[VAL_81:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_43]], %[[VAL_76]] : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%[[VAL_43]] : tensor<128x64xbf16>) { +// CHECK: ^bb0(%[[VAL_82:.*]]: bf16, %[[VAL_83:.*]]: bf16, %[[VAL_84:.*]]: bf16): +// CHECK: %[[VAL_85:.*]] = arith.addf %[[VAL_82]], %[[VAL_83]] : bf16 +// CHECK: linalg.yield %[[VAL_85]] : bf16 +// CHECK: } -> tensor<128x64xbf16> +// CHECK: %[[VAL_86:.*]] = arith.muli %[[VAL_33]], %[[VAL_24]] : index +// CHECK: %[[VAL_87:.*]] = arith.addi %[[VAL_86]], %[[VAL_47]] : index +// CHECK: %[[VAL_88:.*]] = arith.addi %[[VAL_46]], %[[VAL_87]] : index +// CHECK: %[[VAL_89:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_88]]], sizes: [128, 64], strides: {{\[}}%[[VAL_32]], %[[VAL_33]]] : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_32]], %[[VAL_24]] : index +// CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_90]], %[[VAL_48]] : index +// CHECK: %[[VAL_92:.*]] = arith.addi %[[VAL_91]], %[[VAL_49]] : index +// CHECK: %[[VAL_93:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_92]]], sizes: [128, 64], strides: {{\[}}%[[VAL_32]], %[[VAL_33]]] : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: scf.yield %[[VAL_81]], %[[VAL_89]], %[[VAL_93]], %[[VAL_88]], %[[VAL_23]], %[[VAL_92]], %[[VAL_23]] : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, 
?], offset: ?>>, index, index, index, index +// CHECK: } +// CHECK: %[[VAL_94:.*]] = arith.muli %[[VAL_15]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_95:.*]] = arith.index_cast %[[VAL_94]] : i32 to index +// CHECK: %[[VAL_96:.*]] = arith.index_cast %[[VAL_12]] : i32 to index +// CHECK: %[[VAL_97:.*]] = arith.index_cast %[[VAL_13]] : i32 to index +// CHECK: %[[VAL_98:.*]] = arith.muli %[[VAL_31]], %[[VAL_96]] : index +// CHECK: %[[VAL_99:.*]] = arith.muli %[[VAL_95]], %[[VAL_97]] : index +// CHECK: %[[VAL_100:.*]] = arith.index_cast %[[VAL_6]] : i32 to index +// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_98]], %[[VAL_99]] : index +// CHECK: %[[VAL_102:.*]] = memref.reinterpret_cast %[[VAL_4]] to offset: {{\[}}%[[VAL_101]]], sizes: [128, 64], strides: {{\[}}%[[VAL_96]], %[[VAL_97]]] : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: %[[VAL_103:.*]] = arith.divsi %[[VAL_101]], %[[VAL_96]] : index +// CHECK: %[[VAL_104:.*]] = arith.subi %[[VAL_35]], %[[VAL_103]] : index +// CHECK: %[[VAL_105:.*]] = arith.maxsi %[[VAL_104]], %[[VAL_23]] : index +// CHECK: %[[VAL_106:.*]] = arith.minsi %[[VAL_105]], %[[VAL_22]] : index +// CHECK: %[[VAL_107:.*]] = arith.remsi %[[VAL_101]], %[[VAL_96]] : index +// CHECK: %[[VAL_108:.*]] = arith.divsi %[[VAL_107]], %[[VAL_97]] : index +// CHECK: %[[VAL_109:.*]] = arith.subi %[[VAL_100]], %[[VAL_108]] : index +// CHECK: %[[VAL_110:.*]] = arith.maxsi %[[VAL_109]], %[[VAL_23]] : index +// CHECK: %[[VAL_111:.*]] = arith.minsi %[[VAL_110]], %[[VAL_24]] : index +// CHECK: %[[VAL_112:.*]] = tensor.extract_slice %[[VAL_113:.*]]#0[0, 0] {{\[}}%[[VAL_106]], %[[VAL_111]]] [1, 1] : tensor<128x64xbf16> to tensor +// CHECK: %[[VAL_114:.*]] = memref.subview %[[VAL_102]][0, 0] {{\[}}%[[VAL_106]], %[[VAL_111]]] [1, 1] : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination %[[VAL_112]] in writable %[[VAL_114]] : (tensor, memref>) -> () +// CHECK: return +// CHECK: } + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/broadcast.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/broadcast.mlir new file mode 100644 index 000000000..ef1ac4e70 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/broadcast.mlir @@ -0,0 +1,92 @@ +// RUN: triton-adapter-opt -split-input-file --triton-to-linalg=named-ops=true %s | FileCheck %s +module { + tt.func public @fn_broadcast_first_axis(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.reshape %3 : tensor<32xf32> -> tensor<1x4x8xf32> + %5 = tt.broadcast %4 : tensor<1x4x8xf32> -> tensor<128x4x8xf32> + %6 = tt.reshape %5 : tensor<128x4x8xf32> -> tensor<4096xf32> + %7 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> + %8 = tt.splat %arg1 : !tt.ptr -> tensor<4096x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4096x!tt.ptr>, tensor<4096xi32> + tt.store %9, %6 : tensor<4096x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @fn_broadcast_first_axis +// CHECK: %[[VAL_8:.*]] = arith.constant dense<4096> : tensor<1xi64> +// CHECK: %[[VAL_9:.*]] = arith.constant dense<[1, 4, 8]> : tensor<3xi64> +// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [32], strides: [1] : memref to memref<32xf32, strided<[1]>> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<32xf32> +// CHECK: 
memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf32, strided<[1]>> to memref<32xf32> +// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<32xf32> +// CHECK: %[[VAL_13:.*]] = tensor.reshape %[[VAL_12]](%[[VAL_9]]) : (tensor<32xf32>, tensor<3xi64>) -> tensor<1x4x8xf32> +// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<128x4x8xf32> +// CHECK: %[[VAL_15:.*]] = tensor.collapse_shape %[[VAL_13]] {{\[}}[0, 1], [2]] : tensor<1x4x8xf32> into tensor<4x8xf32> +// CHECK: %[[VAL_16:.*]] = linalg.broadcast ins(%[[VAL_15]] : tensor<4x8xf32>) outs(%[[VAL_14]] : tensor<128x4x8xf32>) dimensions = [0] +// CHECK: %[[VAL_19:.*]] = tensor.reshape %[[VAL_16]](%[[VAL_8]]) : (tensor<128x4x8xf32>, tensor<1xi64>) -> tensor<4096xf32> +// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast [[ARG_1:%.+]] to offset: [0], sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]] : (tensor<4096xf32>, memref<4096xf32, strided<[1]>>) -> () +// CHECK: return + + +// ----- +module { + tt.func public @fn_broadcast_middle_axis(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %3 = tt.load %2 : tensor<1024x!tt.ptr> + %4 = tt.reshape %3 : tensor<1024xf32> -> tensor<128x1x8xf32> + %5 = tt.broadcast %4 : tensor<128x1x8xf32> -> tensor<128x4x8xf32> + %6 = tt.reshape %5 : tensor<128x4x8xf32> -> tensor<4096xf32> + %7 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> + %8 = tt.splat %arg1 : !tt.ptr -> tensor<4096x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<4096x!tt.ptr>, tensor<4096xi32> + tt.store %9, %6 : tensor<4096x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @fn_broadcast_middle_axis +// CHECK: %[[VAL_8:.*]] = arith.constant dense<4096> : tensor<1xi64> +// CHECK: %[[VAL_9:.*]] = arith.constant dense<[128, 1, 8]> : tensor<3xi64> +// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1]>> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> +// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_13:.*]] = tensor.reshape %[[VAL_12]](%[[VAL_9]]) : (tensor<1024xf32>, tensor<3xi64>) -> tensor<128x1x8xf32> +// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<128x4x8xf32> +// CHECK: %[[VAL_15:.*]] = tensor.collapse_shape %[[VAL_13]] {{\[}}[0], [1, 2]] : tensor<128x1x8xf32> into tensor<128x8xf32> +// CHECK: %[[VAL_16:.*]] = linalg.broadcast ins(%[[VAL_15]] : tensor<128x8xf32>) outs(%[[VAL_14]] : tensor<128x4x8xf32>) dimensions = [1] +// CHECK: %[[VAL_19:.*]] = tensor.reshape %[[VAL_16]](%[[VAL_8]]) : (tensor<128x4x8xf32>, tensor<1xi64>) -> tensor<4096xf32> +// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast [[ARG_1:%.+]] to offset: [0], sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]] : (tensor<4096xf32>, memref<4096xf32, strided<[1]>>) -> () +// CHECK: return + + +// ----- + +module { + // // CHECK-LABEL: func @fn_broadcast_two_axis + tt.func public @fn_broadcast_two_axis(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %0 = tt.make_range 
{end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<1x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<1x!tt.ptr>, tensor<1xi32> + %3 = tt.load %2 : tensor<1x!tt.ptr> + %4 = tt.reshape %3 : tensor<1xf32> -> tensor<1x1xf32> + // CHECK: %[[TENSOR:.*]] = tensor.empty() : tensor<4x8xf32> + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[RESHAPE:.*]] [] : tensor<1x1xf32> into tensor + // CHECK: %[[BRC:.*]] = linalg.broadcast ins(%[[COLLAPSED]] : tensor) outs(%[[TENSOR]] : tensor<4x8xf32>) dimensions = [0, 1] + %5 = tt.broadcast %4 : tensor<1x1xf32> -> tensor<4x8xf32> + %6 = tt.reshape %5 : tensor<4x8xf32> -> tensor<32xf32> + %7 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %8 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %9 = tt.addptr %8, %7 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %9, %6 : tensor<32x!tt.ptr> + tt.return + } +} \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/cat.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/cat.mlir new file mode 100644 index 000000000..64faee822 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/cat.mlir @@ -0,0 +1,23 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +tt.func public @fn_npu_cat(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr) attributes {noinline = false} { + %0 = tt.make_range {end = 8192 : i32, start = 0 : i32} : tensor<8192xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8192x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8192x!tt.ptr>, tensor<8192xi32> + %3 = tt.load %2 : tensor<8192x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<8192x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<8192x!tt.ptr>, tensor<8192xi32> + %6 = tt.load %5 : tensor<8192x!tt.ptr> + %7 = tt.cat %3, %6 : tensor<8192xf32> -> tensor<16384xf32> + %8 = tt.make_range {end = 16384 : i32, start = 0 : i32} : tensor<16384xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<16384x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<16384x!tt.ptr>, tensor<16384xi32> + tt.store %10, %7 : tensor<16384x!tt.ptr> + tt.return +} +//CHECK-LABEL: @fn_npu_cat +//CHECK-NOT: tt.cat +//CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ADDR0:.*]] restrict writable : memref<8192xf32> +//CHECK: %[[VAL1:.*]] = bufferization.to_tensor %[[ADDR1:.*]] restrict writable : memref<8192xf32> +//CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<16384xf32> +//CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[VAL0]] into %[[EMPTY]][0] [8192] [1] : tensor<8192xf32> into tensor<16384xf32> +//CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[VAL1]] into %[[INSERT0]][8192] [8192] [1] : tensor<8192xf32> into tensor<16384xf32> diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/clampf.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/clampf.mlir new file mode 100644 index 000000000..2c6361cdb --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/clampf.mlir @@ -0,0 +1,45 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: f32 {tt.divisibility = 16 : i32}, %arg2: f32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr 
%1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg1 : f32 -> tensor<32xf32> + %5 = tt.splat %arg2 : f32 -> tensor<32xf32> + %6 = tt.clampf %3, %4, %5, propagateNan = none : tensor<32xf32> + %7 = tt.splat %arg3 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %8, %6 : tensor<32x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: f32 {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: f32 {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[VAR_0:.*]] to offset: [0], sizes: [32], strides: [1] : memref to memref<32xf32, strided<[1]>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32> +// CHECK: memref.copy %[[REINTERPRET_CAST]], %[[ALLOC:.*]] : memref<32xf32, strided<[1]>> to memref<32xf32> +// CHECK: %[[VAR_10:.*]] = bufferization.to_tensor %[[ALLOC:.*]] restrict writable : memref<32xf32> +// CHECK: %[[VAR_11:.*]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[VAR_12:.*]] = linalg.fill ins(%[[VAL_1]] : f32) outs(%[[VAR_11]] : tensor<32xf32>) -> tensor<32xf32> +// CHECK: %[[VAR_14:.*]] = linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAR_11]] : tensor<32xf32>) -> tensor<32xf32> +// CHECK: %[[VAR_15:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAR_10]], %[[VAR_14]] : tensor<32xf32>, tensor<32xf32>) outs(%[[VAR_10]] : tensor<32xf32>) { +// CHECK: ^bb0(%[[VAR_16:.*]]: f32, %[[VAR_17:.*]]: f32, %[[VAR_18:.*]]: f32): +// CHECK: %[[VAR_19:.*]] = arith.minnumf %[[VAR_16]], %[[VAR_17]] : f32 +// CHECK: linalg.yield %[[VAR_19]] : f32 +// CHECK: } -> tensor<32xf32> +// CHECK: %[[VAR_20:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAR_12]], %[[VAR_15]] : tensor<32xf32>, tensor<32xf32>) outs(%[[VAR_12]] : tensor<32xf32>) { +// CHECK: ^bb0(%[[VAR_21:.*]]: f32, %[[VAR_22:.*]]: f32, %[[VAR_23:.*]]: f32): +// CHECK: %[[VAR_19]] = arith.maxnumf %[[VAR_21]], %[[VAR_22]] : f32 +// CHECK: linalg.yield %[[VAR_19]] : f32 +// CHECK: } -> tensor<32xf32> +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[VAL_3]] to offset: [0], sizes: [32], strides: [1] : memref to memref<32xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAR_20]] in writable %[[REINTERPRET_CAST_0]] : (tensor<32xf32>, memref<32xf32, strided<[1]>>) -> () +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/control_flow.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/control_flow.mlir new file mode 100644 index 000000000..d5489eaf3 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/control_flow.mlir @@ -0,0 +1,69 @@ +// RUN: triton-adapter-opt --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --bubble-up-operation --triton-to-linalg %s | FileCheck %s +module { + tt.func public 
@grouped_gemm_triton_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: i32 {tt.divisibility = 16 : i32} , %arg4: i32 {tt.divisibility = 16 : i32} , %arg5: i32 {tt.divisibility = 16 : i32} , %arg6: !tt.ptr {tt.divisibility = 16 : i32} , %arg7: !tt.ptr {tt.divisibility = 16 : i32} , %arg8: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %c127_i32 = arith.constant 127 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32> + %c64_i64 = arith.constant 64 : i64 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xbf16> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xbf16> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<128> : tensor<32x128xi32> + %cst_3 = arith.constant dense<128> : tensor<64x128xi32> + %c128_i32 = arith.constant 128 : i32 + %cst_4 = arith.constant dense<2048> : tensor<32x1xi32> + %c3145728_i64 = arith.constant 3145728 : i64 + %cst_5 = arith.constant dense<2048> : tensor<64x1xi64> + %cst_6 = arith.constant dense<0> : tensor<32xi32> + %cst_7 = arith.constant dense<0> : tensor<64xi32> + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.addptr %arg8, %arg3 : !tt.ptr, i32 + %3 = tt.load %2 : !tt.ptr + %4 = arith.extsi %0 : i32 to i64 + %5 = arith.cmpi sge, %4, %3 : i64 + cf.cond_br %5, ^bb1, ^bb2 + ^bb1: // 2 preds: ^bb0, ^bb2 + tt.return + ^bb2: // pred: ^bb0 + %6 = scf.for %arg9 = %c0_i32 to %arg3 step %c1_i32 iter_args(%arg10 = %c0_i32) -> (i32) : i32 { + %86 = tt.addptr %arg8, %arg9 : !tt.ptr, i32 + %87 = tt.load %86 : !tt.ptr + %88 = arith.cmpi sge, %4, %87 : i64 + %89 = arith.select %88, %arg9, %arg10 : i32 + scf.yield %89 : i32 + } + %7 = tt.addptr %arg8, %6 : !tt.ptr, i32 + %8 = tt.load %7 : !tt.ptr + %21 = arith.cmpi eq, %8, %c0_i64 : i64 + cf.cond_br %21, ^bb1, ^bb3 + ^bb3: // pred: ^bb2 + %100 = tt.get_program_id x : i32 + %101 = arith.muli %100, %arg3 : i32 + %102 = tt.addptr %arg1, %101 : !tt.ptr, i32 + %103 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %104 = tt.splat %102 : !tt.ptr -> tensor<1024x!tt.ptr> + %105 = tt.addptr %104, %103 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %108 = tt.load %105 : tensor<1024x!tt.ptr> + %1017 = math.exp %108 : tensor<1024xbf16> + %1018 = arith.muli %100, %arg3 : i32 + %1019 = tt.addptr %arg0, %1018 : !tt.ptr, i32 + %1020 = tt.splat %1019 : !tt.ptr -> tensor<1024x!tt.ptr> + %1021 = tt.addptr %1020, %103 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %1021, %1017 : tensor<1024x!tt.ptr> + tt.return + } +} + +//CHECK-LABEL: @grouped_gemm_triton_kernel +//CHECK-NOT: cf.cond_br +//CHECK: %[[COND0:.*]] = arith.cmpi sge, %[[VAL0:.*]], %[[VAL1:.*]] : i64 +//CHECK: scf.if %[[COND0]] { +//CHECK: } else { +//CHECK: %[[COND1:.*]] = arith.cmpi eq, %[[VAL2:.*]], %[[VAL3:.*]] : i64 +//CHECK: scf.if %[[COND1]] { +//CHECK: } else { +//CHECK: } +//CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir new file mode 100644 index 000000000..991cb89d9 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir @@ -0,0 +1,77 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt 
--triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func @kernel( + %a : !tt.ptr, + %b : !tt.ptr, + %c : tensor<1024x!tt.ptr> + ) -> () { + %cst = arith.constant dense : tensor<1024xi1> + // offset calculations + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // a pointer + %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // b pointer + %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %am = tt.load %9 : tensor<1024x!tt.ptr> + %bm = tt.load %19 : tensor<1024x!tt.ptr> + %1 = arith.addf %am, %bm : tensor<1024xf32> + %2 = arith.subf %1, %bm : tensor<1024xf32> + %3 = arith.mulf %2, %bm : tensor<1024xf32> + %4 = arith.divf %3, %bm : tensor<1024xf32> + %5 = arith.cmpf "oeq", %4, %bm : tensor<1024xf32> + %6 = arith.select %5, %am, %bm : tensor<1024xi1>, tensor<1024xf32> + tt.store %c, %6 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_2:.*]]: memref<1024xf32>, +// CHECK-SAME: %[[ARG_5:.*]]: i32, [[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1]>> +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1]>> +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_6]], %[[VAL_8]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> +// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> +// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_12:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_9]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_9]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: linalg.yield %[[VAL_16]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_17:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): +// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32 +// CHECK: linalg.yield %[[VAL_22]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_23:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_24:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_24]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): +// CHECK: %[[VAL_28:.*]] = 
arith.mulf %[[VAL_25]], %[[VAL_26]] : f32 +// CHECK: linalg.yield %[[VAL_28]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_29:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_30:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_30]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_31:.*]]: f32, %[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: f32): +// CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32 +// CHECK: linalg.yield %[[VAL_34]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_35:.*]] = tensor.empty() : tensor<1024xi1> +// CHECK: %[[VAL_36:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_37:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_35]] : tensor<1024xi1>) { +// CHECK: ^bb0(%[[VAL_38:.*]]: f32, %[[VAL_39:.*]]: f32, %[[VAL_40:.*]]: i1): +// CHECK: %[[VAL_41:.*]] = arith.cmpf oeq, %[[VAL_38]], %[[VAL_39]] : f32 +// CHECK: linalg.yield %[[VAL_41]] : i1 +// CHECK: } -> tensor<1024xi1> +// CHECK: %[[VAL_42:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_43:.*]], %[[VAL_9]], %[[VAL_11]] : tensor<1024xi1>, tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_9]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_44:.*]]: i1, %[[VAL_45:.*]]: f32, %[[VAL_46:.*]]: f32, %[[VAL_47:.*]]: f32): +// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_44]], %[[VAL_45]], %[[VAL_46]] : f32 +// CHECK: linalg.yield %[[VAL_48]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: bufferization.materialize_in_destination %[[VAL_49:.*]] in writable %[[VAL_2]] +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir new file mode 100644 index 000000000..887d73763 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir @@ -0,0 +1,54 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func @kernel( + %a : !tt.ptr, + %b : !tt.ptr, + %c : !tt.ptr, + %d : tensor<1024x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // a pointer + %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // b pointer + %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // c pointer + %28 = tt.splat %c : !tt.ptr -> tensor<1024x!tt.ptr> + %29 = tt.addptr %28, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %am = tt.load %9 : tensor<1024x!tt.ptr> + %bm = tt.load %19 : tensor<1024x!tt.ptr> + %cm = tt.load %29 : tensor<1024x!tt.ptr> + %10 = arith.select %am, %bm, %cm : tensor<1024xi1>, tensor<1024xf32> + tt.store %d, %10 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_2:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_3:.*]]: memref<1024xf32>, +// CHECK-SAME: %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, 
%[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xi1, strided<[1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1]>> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<1024xi1> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<1024xi1, strided<[1]>> to memref<1024xi1> +// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<1024xi1> +// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_8]], %[[VAL_12]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> +// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_9]], %[[VAL_14]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<1024xi1>, tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_13]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_17:.*]]: i1, %[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32): +// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32 +// CHECK: linalg.yield %[[VAL_21]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in writable %[[VAL_3]] +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir new file mode 100644 index 000000000..6f1eac4ed --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir @@ -0,0 +1,92 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func @kernel( + %f32ptr : !tt.ptr, + %intptr : !tt.ptr, + %f16ptr : !tt.ptr, + %save0 : tensor<1024x!tt.ptr>, + %save1 : tensor<1024x!tt.ptr>, + %save2 : tensor<1024x!tt.ptr>, + %save3 : tensor<1024x!tt.ptr>, + %save4 : tensor<1024x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + // f32ptr pointer + %8 = tt.splat %f32ptr : !tt.ptr -> tensor<1024x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // intptr pointer + %18 = tt.splat %intptr : !tt.ptr -> tensor<1024x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + // f32ptr pointer + %28 = tt.splat %f16ptr : !tt.ptr -> tensor<1024x!tt.ptr> + %29 = tt.addptr %28, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %afm = tt.load %9 : tensor<1024x!tt.ptr> + %aim = tt.load %19 : tensor<1024x!tt.ptr> + %bfm = tt.load %29 : 
tensor<1024x!tt.ptr> + %5 = arith.truncf %afm : tensor<1024xf32> to tensor<1024xbf16> + %6 = math.exp %afm : tensor<1024xf32> + %7 = arith.sitofp %aim : tensor<1024xi32> to tensor<1024xf32> + %10 = arith.extf %bfm : tensor<1024xf16> to tensor<1024xf32> + %11 = math.sqrt %afm : tensor<1024xf32> + tt.store %save0, %5 : tensor<1024x!tt.ptr> + tt.store %save1, %6 : tensor<1024x!tt.ptr> + tt.store %save2, %7 : tensor<1024x!tt.ptr> + tt.store %save3, %10 : tensor<1024x!tt.ptr> + tt.store %save4, %11 : tensor<1024x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_2:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_3:.*]]: memref<1024xbf16>, +// CHECK-SAME: %[[VAL_4:.*]]: memref<1024xf32>, %[[VAL_5:.*]]: memref<1024xf32>, %[[VAL_6:.*]]: memref<1024xf32>, %[[VAL_7:.*]]: memref<1024xf32>, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf32, strided<[1]>> +// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xi32, strided<[1]>> +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [1024], strides: [1] : memref to memref<1024xf16, strided<[1]>> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1024xf32> +// CHECK: memref.copy %[[VAL_11]], %[[VAL_14]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<1024xf32> +// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<1024xi32> +// CHECK: memref.copy %[[VAL_12]], %[[VAL_16]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> +// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<1024xi32> +// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xf16> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_18]] : memref<1024xf16, strided<[1]>> to memref<1024xf16> +// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xf16> +// CHECK: %[[VAL_20:.*]] = tensor.empty() : tensor<1024xbf16> +// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_20]] : tensor<1024xbf16>) { +// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: bf16): +// CHECK: %[[VAL_24:.*]] = arith.truncf %[[VAL_22]] : f32 to bf16 +// CHECK: linalg.yield %[[VAL_24]] : bf16 +// CHECK: } -> tensor<1024xbf16> +// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_15]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): +// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 +// CHECK: linalg.yield %[[VAL_28]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<1024xf32> +// CHECK: %[[VAL_30:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_17]] : tensor<1024xi32>) outs(%[[VAL_29]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_31:.*]]: i32, %[[VAL_32:.*]]: f32): +// CHECK: %[[VAL_33:.*]] = 
arith.sitofp %[[VAL_31]] : i32 to f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_19]] : tensor<1024xf16>) outs(%[[VAL_29]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_36:.*]]: f16, %[[VAL_37:.*]]: f32): +// CHECK: %[[VAL_38:.*]] = arith.extf %[[VAL_36]] : f16 to f32 +// CHECK: linalg.yield %[[VAL_38]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL_39:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_15]] : tensor<1024xf32>) { +// CHECK: ^bb0(%[[VAL_40:.*]]: f32, %[[VAL_41:.*]]: f32): +// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32 +// CHECK: linalg.yield %[[VAL_42]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in writable %[[VAL_3]] +// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in writable %[[VAL_4]] +// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in writable %[[VAL_5]] +// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in writable %[[VAL_6]] +// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in writable %[[VAL_7]] +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir new file mode 100644 index 000000000..c437ec657 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir @@ -0,0 +1,60 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func @kernel( + %a : !tt.ptr, + %b : !tt.ptr, + %c : tensor<128x128x!tt.ptr>, + %d : tensor<128x128x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> + %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> + // a pointer + %8 = tt.splat %a : !tt.ptr -> tensor<128x128x!tt.ptr> + %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // b pointer + %18 = tt.splat %b : !tt.ptr -> tensor<128x128x!tt.ptr> + %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + %af = tt.load %9 : tensor<128x128x!tt.ptr> + %bf = tt.load %19 : tensor<128x128x!tt.ptr> + %res0 = arith.addf %af, %bf : tensor<128x128xf32> + %res1 = arith.subf %af, %bf : tensor<128x128xf32> + tt.store %c, %res0 : tensor<128x128x!tt.ptr> + tt.store %d, %res1 : tensor<128x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_2:.*]]: memref<128x128xf32>, %[[VAL_3:.*]]: memref<128x128xf32>, +// CHECK-SAME: %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: 
i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1]>> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<128x128xf32> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<128x128xf32> +// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<128x128xf32> +// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> +// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<128x128xf32> +// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_10]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32, %[[VAL_16:.*]]: f32): +// CHECK: %[[VAL_17:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: linalg.yield %[[VAL_17]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_10]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): +// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32 +// CHECK: linalg.yield %[[VAL_22]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_2]] +// CHECK: bufferization.materialize_in_destination %[[VAL_24:.*]] in writable %[[VAL_3]] +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir new file mode 100644 index 000000000..817f8b03d --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir @@ -0,0 +1,60 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func @kernel( + %a : !tt.ptr, + %b : !tt.ptr, + %c : !tt.ptr, + %d : tensor<128x128x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> + %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> + // a pointer + %8 = tt.splat %a : !tt.ptr -> tensor<128x128x!tt.ptr> + %9 = tt.addptr %8, %mkoff : 
tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // b pointer + %18 = tt.splat %b : !tt.ptr -> tensor<128x128x!tt.ptr> + %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // c pointer + %28 = tt.splat %c : !tt.ptr -> tensor<128x128x!tt.ptr> + %29 = tt.addptr %28, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + %am = tt.load %9 : tensor<128x128x!tt.ptr> + %bm = tt.load %19 : tensor<128x128x!tt.ptr> + %cm = tt.load %29 : tensor<128x128x!tt.ptr> + %100 = arith.select %am, %bm, %cm : tensor<128x128xi1>, tensor<128x128xf32> + tt.store %d, %100 : tensor<128x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_2:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_3:.*]]: memref<128x128xf32>, +// CHECK-SAME: %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xi1, strided<[1, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1]>> +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1]>> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128x128xi1> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<128x128xi1, strided<[1, 1]>> to memref<128x128xi1> +// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128x128xi1> +// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<128x128xf32> +// CHECK: memref.copy %[[VAL_8]], %[[VAL_12]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> +// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<128x128xf32> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<128x128xf32> +// CHECK: memref.copy %[[VAL_9]], %[[VAL_14]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<128x128xf32> +// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<128x128xi1>, tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_13]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_17:.*]]: i1, %[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32): +// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32 +// CHECK: linalg.yield %[[VAL_21]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in writable %[[VAL_3]] +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir new file mode 100644 index 000000000..aa1336824 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir @@ -0,0 
+1,98 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s --check-prefix NAMED +module { + tt.func @kernel( + %f32ptr : !tt.ptr, + %intptr : !tt.ptr, + %f16ptr : !tt.ptr, + %save0 : tensor<128x128x!tt.ptr>, + %save1 : tensor<128x128x!tt.ptr>, + %save2 : tensor<128x128x!tt.ptr>, + %save3 : tensor<128x128x!tt.ptr>, + %save4 : tensor<128x128x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> + %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> + // f32ptr pointer + %8 = tt.splat %f32ptr : !tt.ptr -> tensor<128x128x!tt.ptr> + %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // intptr pointer + %18 = tt.splat %intptr : !tt.ptr -> tensor<128x128x!tt.ptr> + %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + // f16ptr pointer + %28 = tt.splat %f16ptr : !tt.ptr -> tensor<128x128x!tt.ptr> + %29 = tt.addptr %28, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + %afm = tt.load %9 : tensor<128x128x!tt.ptr> + %aim = tt.load %19 : tensor<128x128x!tt.ptr> + %bfm = tt.load %29 : tensor<128x128x!tt.ptr> + %5 = arith.truncf %afm : tensor<128x128xf32> to tensor<128x128xbf16> + %6 = math.exp %afm : tensor<128x128xf32> + %7 = arith.sitofp %aim : tensor<128x128xi32> to tensor<128x128xf32> + %10 = arith.extf %bfm : tensor<128x128xf16> to tensor<128x128xf32> + %11 = math.sqrt %afm : tensor<128x128xf32> + tt.store %save0, %5 : tensor<128x128x!tt.ptr> + tt.store %save1, %6 : tensor<128x128x!tt.ptr> + tt.store %save2, %7 : tensor<128x128x!tt.ptr> + tt.store %save3, %10 : tensor<128x128x!tt.ptr> + tt.store %save4, %11 : tensor<128x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_2:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_3:.*]]: memref<128x128xbf16>, +// CHECK-SAME: %[[VAL_4:.*]]: memref<128x128xf32>, %[[VAL_5:.*]]: memref<128x128xf32>, %[[VAL_6:.*]]: memref<128x128xf32>, %[[VAL_7:.*]]: memref<128x128xf32>, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf32, strided<[1, 1]>> +// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xi32, strided<[1, 1]>> +// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref to memref<128x128xf16, strided<[1, 1]>> +// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<128x128xf32> +// CHECK: memref.copy %[[VAL_11]], %[[VAL_14]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> +// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : 
memref<128x128xf32> +// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<128x128xi32> +// CHECK: memref.copy %[[VAL_12]], %[[VAL_16]] : memref<128x128xi32, strided<[1, 1]>> to memref<128x128xi32> +// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<128x128xi32> +// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<128x128xf16> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_18]] : memref<128x128xf16, strided<[1, 1]>> to memref<128x128xf16> +// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x128xf16> +// CHECK: %[[VAL_20:.*]] = tensor.empty() : tensor<128x128xbf16> +// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_20]] : tensor<128x128xbf16>) { +// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: bf16): +// CHECK: %[[VAL_24:.*]] = arith.truncf %[[VAL_22]] : f32 to bf16 +// CHECK: linalg.yield %[[VAL_24]] : bf16 +// CHECK: } -> tensor<128x128xbf16> +// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_15]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): +// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 +// CHECK: linalg.yield %[[VAL_28]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<128x128xf32> +// CHECK: %[[VAL_30:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_17]] : tensor<128x128xi32>) outs(%[[VAL_29]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_31:.*]]: i32, %[[VAL_32:.*]]: f32): +// CHECK: %[[VAL_33:.*]] = arith.sitofp %[[VAL_31]] : i32 to f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_19]] : tensor<128x128xf16>) outs(%[[VAL_29]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_36:.*]]: f16, %[[VAL_37:.*]]: f32): +// CHECK: %[[VAL_38:.*]] = arith.extf %[[VAL_36]] : f16 to f32 +// CHECK: linalg.yield %[[VAL_38]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: %[[VAL_39:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_15]] : tensor<128x128xf32>) { +// CHECK: ^bb0(%[[VAL_40:.*]]: f32, %[[VAL_41:.*]]: f32): +// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32 +// CHECK: linalg.yield %[[VAL_42]] : f32 +// CHECK: } -> tensor<128x128xf32> +// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in writable %[[VAL_3]] +// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in writable %[[VAL_4]] +// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in writable %[[VAL_5]] +// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in writable %[[VAL_6]] +// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in writable %[[VAL_7]] +// CHECK: return +// CHECK: } + +// NAMED-LABEL: func.func @kernel +// NAMED-NOT: linalg.generic diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir new file mode 100644 index 000000000..414c316e5 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir @@ 
-0,0 +1,22 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func public @addi(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.addi %arg14, %arg15 : i32 + tt.reduce.return %69 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + + +// CHECK: %[[VAL_1:.*]] = tensor.extract %reduced[] : tensor +// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = linalg.fill ins(%[[VAL_1]] : i32) outs(%[[VAL_2]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_3]] in writable %[[VAL_4]] : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: return diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir new file mode 100644 index 000000000..64bdd5443 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir @@ -0,0 +1,144 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s +module { + tt.func public @argmax_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> + %3 = tt.splat %1 : i32 -> tensor<4096xi32> + %4 = arith.addi %3, %2 : tensor<4096xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> + %7 = tt.load %6 : tensor<4096x!tt.ptr> + %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ + ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + %12 = arith.cmpi slt, %arg10, %arg12 : i32 + %13 = arith.andi %11, %12 : i1 + %14 = arith.cmpf ogt, %arg9, %arg11 : f32 + %15 = arith.ori %14, %13 : i1 + %16 = arith.select %15, %arg9, %arg11 : f32 + %17 = arith.select %15, %arg10, %arg12 : i32 + tt.reduce.return %16, %17 : f32, i32 + }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) + %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 + tt.store %9, %8#1 : !tt.ptr + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @argmax_012 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_2_:%.+]]: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli %arg8, [[PARAM_2_]] : i32 +// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> +// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { +// CHECK: ^bb0([[out_:.+]]: i32): +// CHECK: [[VAR_10_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i32 +// CHECK: linalg.yield [[VAR_11_]] : i32 +// CHECK: } -> 
tensor<4096xi32> +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_5_]] : tensor) -> tensor +// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor +// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_7_]] : tensor) -> tensor +// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xf32>, tensor<4096xi32>) outs([[VAR_6_]], [[VAR_8_]] : tensor, tensor) dimensions = [0] +// CHECK: ([[in:.+]]: f32, [[in_1:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { +// CHECK-DAG: [[VAR_10_1_:%.+]] = arith.cmpf ogt, [[in]], [[init]] : f32 +// CHECK-DAG: [[VAR_11_1_:%.+]] = arith.cmpf oeq, [[in]], [[init]] : f32 +// CHECK-DAG: [[VAR_12_1_:%.+]] = arith.cmpi slt, [[in_1]], [[init_2]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_13_:%.+]] = arith.andi [[VAR_11_1_]], [[VAR_12_1_]] : i1 +// CHECK: [[VAR_14_:%.+]] = arith.ori [[VAR_10_1_]], [[VAR_13_]] : i1 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.select [[VAR_14_]], [[in]], [[init]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[in_1]], [[init_2]] : i32 +// CHECK: linalg.yield [[VAR_15_]], [[VAR_16_]] : f32, i32 +// CHECK: } +// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor +// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast %arg8 : i32 to index +// CHECK: [[VAR_17_:%.+]] = tensor.empty() : tensor<1xi32> +// CHECK: [[VAR_18_:%.+]] = linalg.fill ins([[VAR_extracted_]] : i32) outs([[VAR_17_]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_18_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<1xi32>, memref<1xi32, strided<[1], offset: ?>>) -> () +// CHECK: return +// CHECK: } + +// ----- + +module { + tt.func public @argmin_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> + %3 = tt.splat %1 : i32 -> tensor<4096xi32> + %4 = arith.addi %3, %2 : tensor<4096xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> + %7 = tt.load %6 : tensor<4096x!tt.ptr> + %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ + ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + %11 = arith.cmpf oeq, %arg9, %arg11 : f32 + %12 = arith.cmpi slt, %arg10, %arg12 : i32 + %13 = arith.andi %11, %12 : i1 + %14 = arith.cmpf olt, %arg9, %arg11 : f32 + %15 = arith.ori %14, %13 : i1 + %16 = arith.select %15, %arg9, %arg11 : f32 + %17 = arith.select %15, %arg10, %arg12 : i32 + tt.reduce.return %16, %17 : f32, i32 + }) : (tensor<4096xf32>, 
tensor<4096xi32>) -> (f32, i32) + %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 + tt.store %9, %8#1 : !tt.ptr + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @argmin_012 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_2_:%.+]]: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli %arg8, [[PARAM_2_]] : i32 +// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> +// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { +// CHECK: ^bb0([[out_:.+]]: i32): +// CHECK: [[VAR_10_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i32 +// CHECK: linalg.yield [[VAR_11_]] : i32 +// CHECK: } -> tensor<4096xi32> +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_5_]] : tensor) -> tensor +// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor +// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_7_]] : tensor) -> tensor +// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xf32>, tensor<4096xi32>) outs([[VAR_6_]], [[VAR_8_]] : tensor, tensor) dimensions = [0] +// CHECK: ([[in:.+]]: f32, [[in_1:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { +// CHECK-DAG: [[VAR_10_1_:%.+]] = arith.cmpf olt, [[in]], [[init]] : f32 +// CHECK-DAG: [[VAR_11_1_:%.+]] = arith.cmpf oeq, [[in]], [[init]] : f32 +// CHECK-DAG: [[VAR_12_1_:%.+]] = arith.cmpi slt, [[in_1]], [[init_2]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_13_:%.+]] = arith.andi [[VAR_11_1_]], [[VAR_12_1_]] : i1 +// CHECK: [[VAR_14_:%.+]] = arith.ori [[VAR_10_1_]], [[VAR_13_]] : i1 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.select [[VAR_14_]], [[in]], [[init]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[in_1]], [[init_2]] : i32 +// CHECK: linalg.yield [[VAR_15_]], [[VAR_16_]] : f32, i32 +// CHECK: } +// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor +// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast %arg8 : i32 to index +// CHECK: [[VAR_17_:%.+]] = tensor.empty() : tensor<1xi32> +// CHECK: [[VAR_18_:%.+]] = linalg.fill ins([[VAR_extracted_]] : i32) outs([[VAR_17_]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [1], strides: [1] : 
memref to memref<1xi32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_18_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<1xi32>, memref<1xi32, strided<[1], offset: ?>>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir new file mode 100644 index 000000000..7a4f8dee3 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir @@ -0,0 +1,199 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +// @triton.jit +// def test( +// a_ptr, c_ptr, stride_am, stride_an +// ): +// offs_am = tl.arange(0, 4) +// offs_an = tl.arange(0, 4) +// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) +// a = tl.load(a_ptrs) +// m = tl.argmax(a, axis=1) +// tl.store(c_ptr + tl.arange(0, 4), m) +// +// ret = triton.compiler.compile( +// test, +// signature=" *fp32,*fp32,i32,i32", +// print_triton_ir_only=True, +// ) + +module { + tt.func public @test_argmax(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %2 = tt.splat %arg2 : i32 -> tensor<4x1xi32> + %3 = arith.muli %1, %2 : tensor<4x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1x4xi32> + %6 = arith.muli %4, %5 : tensor<1x4xi32> + %7 = tt.broadcast %3 : tensor<4x1xi32> -> tensor<4x4xi32> + %8 = tt.broadcast %6 : tensor<1x4xi32> -> tensor<4x4xi32> + %9 = arith.addi %7, %8 : tensor<4x4xi32> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %12 = tt.load %11 : tensor<4x4x!tt.ptr> + %13 = tt.broadcast %4 : tensor<1x4xi32> -> tensor<4x4xi32> + %14:2 = "tt.reduce"(%12, %13) <{axis = 1 : i32}> ({ + ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): + %18 = arith.cmpf oeq, %arg4, %arg6 : f32 + %19 = arith.cmpi slt, %arg5, %arg7 : i32 + %20 = arith.andi %18, %19 : i1 + %21 = arith.cmpf ogt, %arg4, %arg6 : f32 + %22 = arith.ori %21, %20 : i1 + %23 = arith.select %22, %arg4, %arg6 : f32 + %24 = arith.select %22, %arg5, %arg7 : i32 + tt.reduce.return %23, %24 : f32, i32 + }) : (tensor<4x4xf32>, tensor<4x4xi32>) -> (tensor<4xf32>, tensor<4xi32>) + %15 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %16 = tt.addptr %15, %0 : tensor<4x!tt.ptr>, tensor<4xi32> + %17 = arith.sitofp %14#1 : tensor<4xi32> to tensor<4xf32> + tt.store %16, %17 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @test_argmax +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> +// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = 
["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { +// CHECK: ^bb0([[out_:.+]]: i32): +// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32 +// CHECK: linalg.yield [[VAR_14_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 4], strides: {{.}}[[VAR_2_]], [[VAR_3_]]{{.}} : memref to memref<4x4xf32, strided<[?, ?]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x4xf32, strided<[?, ?]>> to memref<4x4xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<4x4xi32> +// CHECK: [[VAR_6_:%.+]] = linalg.broadcast ins([[VAR_1_]] : tensor<4xi32>) outs([[VAR_5_]] : tensor<4x4xi32>) dimensions = [0] +// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<4xf32> +// CHECK-DAG: [[VAR_8_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_7_]] : tensor<4xf32>) -> tensor<4xf32> +// CHECK: [[VAR_10_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_6_]] : tensor<4x4xf32>, tensor<4x4xi32>) outs([[VAR_8_]], [[VAR_10_]] : tensor<4xf32>, tensor<4xi32>) dimensions = [1] +// CHECK: ([[in_:.+]]: f32, [[in_1_:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { +// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpf ogt, [[in_]], [[init]] : f32 +// CHECK-DAG: [[VAR_14_1_:%.+]] = arith.cmpf oeq, [[in_]], [[init]] : f32 +// CHECK-DAG: [[VAR_15_1_:%.+]] = arith.cmpi slt, [[in_1_]], [[init_2]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.andi [[VAR_14_1_]], [[VAR_15_1_]] : i1 +// CHECK: [[VAR_17_:%.+]] = arith.ori [[VAR_13_]], [[VAR_16_]] : i1 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[in_]], [[init]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[in_1_]], [[init_2]] : i32 +// CHECK: linalg.yield [[VAR_18_]], [[VAR_19_]] : f32, i32 +// CHECK: } +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf32, strided<[1]>> +// CHECK: [[VAR_12_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_reduced_]]#1 : tensor<4xi32>) outs([[VAR_7_]] : tensor<4xf32>) { +// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: f32): +// CHECK: [[VAR_13_2_:%.+]] = arith.sitofp [[in_]] : i32 to f32 +// CHECK: linalg.yield [[VAR_13_2_]] : f32 +// CHECK: } -> tensor<4xf32> +// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_0_]] +// CHECK: return +// CHECK: } + +// ----- + +// @triton.jit +// def test( +// a_ptr, c_ptr, stride_am, stride_an +// ): +// offs_am = tl.arange(0, 4) +// offs_an = tl.arange(0, 4) +// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) +// a = tl.load(a_ptrs) +// m = tl.argmin(a, axis=1) +// tl.store(c_ptr + tl.arange(0, 4), m) +// +// ret = triton.compiler.compile( +// test, +// signature=" *fp32,*fp32,i32,i32", +// print_triton_ir_only=True, +// ) + +module { + tt.func public 
@test_argmin(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %2 = tt.splat %arg2 : i32 -> tensor<4x1xi32> + %3 = arith.muli %1, %2 : tensor<4x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1x4xi32> + %6 = arith.muli %4, %5 : tensor<1x4xi32> + %7 = tt.broadcast %3 : tensor<4x1xi32> -> tensor<4x4xi32> + %8 = tt.broadcast %6 : tensor<1x4xi32> -> tensor<4x4xi32> + %9 = arith.addi %7, %8 : tensor<4x4xi32> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %12 = tt.load %11 : tensor<4x4x!tt.ptr> + %13 = tt.broadcast %4 : tensor<1x4xi32> -> tensor<4x4xi32> + %14:2 = "tt.reduce"(%12, %13) <{axis = 1 : i32}> ({ + ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): + %18 = arith.cmpf oeq, %arg4, %arg6 : f32 + %19 = arith.cmpi slt, %arg5, %arg7 : i32 + %20 = arith.andi %18, %19 : i1 + %21 = arith.cmpf olt, %arg4, %arg6 : f32 + %22 = arith.ori %21, %20 : i1 + %23 = arith.select %22, %arg4, %arg6 : f32 + %24 = arith.select %22, %arg5, %arg7 : i32 + tt.reduce.return %23, %24 : f32, i32 + }) : (tensor<4x4xf32>, tensor<4x4xi32>) -> (tensor<4xf32>, tensor<4xi32>) + %15 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %16 = tt.addptr %15, %0 : tensor<4x!tt.ptr>, tensor<4xi32> + %17 = arith.sitofp %14#1 : tensor<4xi32> to tensor<4xf32> + tt.store %16, %17 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @test_argmin +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> +// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { +// CHECK: ^bb0([[out_:.+]]: i32): +// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32 +// CHECK: linalg.yield [[VAR_14_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 4], strides: {{.}}[[VAR_2_]], [[VAR_3_]]{{.}} : memref to memref<4x4xf32, strided<[?, ?]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x4xf32, strided<[?, ?]>> to memref<4x4xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<4x4xi32> +// CHECK: [[VAR_6_:%.+]] = linalg.broadcast ins([[VAR_1_]] : tensor<4xi32>) outs([[VAR_5_]] : tensor<4x4xi32>) dimensions 
= [0] +// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<4xf32> +// CHECK-DAG: [[VAR_8_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_7_]] : tensor<4xf32>) -> tensor<4xf32> +// CHECK: [[VAR_10_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_6_]] : tensor<4x4xf32>, tensor<4x4xi32>) outs([[VAR_8_]], [[VAR_10_]] : tensor<4xf32>, tensor<4xi32>) dimensions = [1] +// CHECK: ([[in_:.+]]: f32, [[in_1_:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { +// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpf olt, [[in_]], [[init]] : f32 +// CHECK-DAG: [[VAR_14_1_:%.+]] = arith.cmpf oeq, [[in_]], [[init]] : f32 +// CHECK-DAG: [[VAR_15_1_:%.+]] = arith.cmpi slt, [[in_1_]], [[init_2]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.andi [[VAR_14_1_]], [[VAR_15_1_]] : i1 +// CHECK: [[VAR_17_:%.+]] = arith.ori [[VAR_13_]], [[VAR_16_]] : i1 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[in_]], [[init]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[in_1_]], [[init_2]] : i32 +// CHECK: linalg.yield [[VAR_18_]], [[VAR_19_]] : f32, i32 +// CHECK: } +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf32, strided<[1]>> +// CHECK: [[VAR_12_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_reduced_]]#1 : tensor<4xi32>) outs([[VAR_7_]] : tensor<4xf32>) { +// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: f32): +// CHECK: [[VAR_13_2_:%.+]] = arith.sitofp [[in_]] : i32 to f32 +// CHECK: linalg.yield [[VAR_13_2_]] : f32 +// CHECK: } -> tensor<4xf32> +// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_0_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_cl_extern_elementwise.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_cl_extern_elementwise.mlir new file mode 100644 index 000000000..b636aa35b --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_cl_extern_elementwise.mlir @@ -0,0 +1,80 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func public @fabs_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 evictionPolicy = evict_last : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__hmf_fabsf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 evictionPolicy = evict_last : tensor<32x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: func.func private @__hmf_fabsf(f32) -> f32 attributes {llvm.readnone} +// CHECK-LABEL: func.func @fabs_kernel_012 +// CHECK: [[RES:%.+]] = linalg.map { func.call {callee = 
@__hmf_fabsf} } ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]]: tensor<32xf32>) + +// ----- + +module { + tt.func public @sqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 evictionPolicy = evict_last : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__hmf_sqrtf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 evictionPolicy = evict_last : tensor<32x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: func.func private @__hmf_sqrtf(f32) -> f32 attributes {llvm.readnone} +// CHECK-LABEL: func.func @sqrt_kernel_012 +// CHECK: [[RES:%.+]] = linalg.map { func.call {callee = @__hmf_sqrtf} } ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]]: tensor<32xf32>) + +// ----- + +module { + tt.func public @rsqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 evictionPolicy = evict_last : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__hmf_rsqrtf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 evictionPolicy = evict_last : tensor<32x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: func.func private @__hmf_rsqrtf(f32) -> f32 attributes {llvm.readnone} +// CHECK-LABEL: func.func @rsqrt_kernel_012 +// CHECK: [[RES:%.+]] = linalg.map { func.call {callee = @__hmf_rsqrtf} } ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]]: tensor<32xf32>) diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir new file mode 100644 index 000000000..b9c013c1b --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir @@ -0,0 +1,709 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func public @atan2_kernel_0123(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg3 : i32 -> 
tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %12 = tt.load %11, %6 : tensor<32x!tt.ptr> + %13 = tt.extern_elementwise %9, %12 {libname = "", libpath = "", pure = true, symbol = "__nv_atan2f"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %15, %13 : tensor<32x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @atan2_kernel_0123 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]], [[VAR_2:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_atan2f"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + + +// ----- + +module { + tt.func public @pow_kernel_0123(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg3 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %12 = tt.load %11, %6 : tensor<32x!tt.ptr> + %13 = tt.extern_elementwise %9, %12 {libname = "", libpath = "", pure = true, symbol = "__nv_powf"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %15, %13 : tensor<32x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @pow_kernel_0123 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]], [[VAR_2:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_powf"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + + +// ----- + +module { + tt.func public @fabs_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_fabsf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @fabs_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_fabsf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + 
+module { + tt.func public @sin_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sinf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @sin_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_sinf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @cos_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_cosf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @cos_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_cosf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @tan_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_tanf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @tan_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_tanf"} : (tensor<32xf32>) -> tensor<32xf32> + + +// ----- + +module { + tt.func public 
@asin_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_asinf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @asin_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_asinf"} : (tensor<32xf32>) -> tensor<32xf32> + + +// ----- + +module { + tt.func public @acos_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_acosf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @acos_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_acosf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @atan_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_atanf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @atan_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_atanf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @sinh_kernel_012(%arg0: !tt.ptr, 
%arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sinhf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @sinh_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_sinhf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @cosh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_coshf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @cosh_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_coshf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @tanh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_tanhf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @tanh_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_tanhf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @asinh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) 
attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_asinhf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @asinh_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_asinhf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @acosh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_acoshf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @acosh_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_acoshf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @atanh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_atanhf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @atanh_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_atanhf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @log_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = 
false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_logf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @log_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_logf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @log10_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_log10f"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @log10_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_log10f"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @log1p_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @log1p_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @exp_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = 
arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_expf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @exp_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_expf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @exp2_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_exp2f"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @exp2_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_exp2f"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @erf_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_erff"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @erf_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_erff"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @sqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 =
tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sqrtf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @sqrt_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_sqrtf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @rsqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_rsqrtf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @rsqrt_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_rsqrtf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @ceil_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_ceilf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @ceil_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_ceilf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @floor_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 
= arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_floorf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @floor_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_floorf"} : (tensor<32xf32>) -> tensor<32xf32> + +// ----- + +module { + tt.func public @trunc_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.splat %arg2 : i32 -> tensor<32xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %9 = tt.load %8, %6 : tensor<32x!tt.ptr> + %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_truncf"} : (tensor<32xf32>) -> tensor<32xf32> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %12, %10 : tensor<32x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @trunc_kernel_012 +// CHECK: [[RES:%.+]] = tt.extern_elementwise [[VAR_1:%.+]] {libname = "", libpath = "", pure = true, symbol = "__nv_truncf"} : (tensor<32xf32>) -> tensor<32xf32> diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax.mlir new file mode 100644 index 000000000..b8778713a --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax.mlir @@ -0,0 +1,51 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s +module { + tt.func public @minmax_olt(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { + %0 = arith.cmpf olt, %arg1, %arg2 : f32 + %1 = arith.select %0, %arg1, %arg2 : f32 + tt.store %arg0, %1 : !tt.ptr + tt.return + } +} +// CHECK: func.func @minmax_olt +// CHECK: %[[VAL:.*]] = arith.cmpf olt, [[ARG_1:%.+]], [[ARG_2:%.+]] : f32 + + +// ----- + +module { + tt.func public @minmax_ole(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { + %0 = arith.cmpf ole, %arg1, %arg2 : f32 + %1 = arith.select %0, %arg1, %arg2 : f32 + tt.store %arg0, %1 : !tt.ptr + tt.return + } +} +// CHECK: func.func @minmax_ole +// CHECK: %[[VAL:.*]] = arith.cmpf ole, [[ARG_1:%.+]], [[ARG_2:%.+]] : f32 + +// ----- + +module { + tt.func public @minmax_ogt(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { + %0 = arith.cmpf ogt, %arg1, %arg2 : f32 + %1 = arith.select %0, %arg1, %arg2 : f32 + tt.store %arg0, %1 : !tt.ptr + tt.return + } +} +// CHECK: func.func @minmax_ogt +// CHECK: %[[VAL:.*]] = arith.cmpf ogt, [[ARG_1:%.+]], [[ARG_2:%.+]] : f32 + +// 
----- + +module { + tt.func public @minmax_oge(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { + %0 = arith.cmpf oge, %arg1, %arg2 : f32 + %1 = arith.select %0, %arg1, %arg2 : f32 + tt.store %arg0, %1 : !tt.ptr + tt.return + } +} +// CHECK: func.func @minmax_oge +// CHECK: %[[VAL:.*]] = arith.cmpf oge, [[ARG_1:%.+]], [[ARG_2:%.+]] : f32 diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir new file mode 100644 index 000000000..35d4d997a --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir @@ -0,0 +1,70 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func public @maxnumf(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<4096xf32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: f32, %arg15: f32): + %69 = arith.maxnumf %arg14, %arg15 : f32 + tt.reduce.return %69 : f32 + }) {axis = 0 : i32} : (tensor<4096xf32>) -> f32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + +// CHECK-LABEL: func.func @maxnumf +// CHECK: %[[CST:.*]] = arith.constant 0xFF800000 : f32 +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<4096xf32> +// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[VAL_0]] : tensor<4096xf32>) -> tensor<4096xf32> +// CHECK: %[[VAL_2:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_2]] : tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = linalg.reduce ins(%[[VAL_1]] : tensor<4096xf32>) outs(%[[VAL_3]] : tensor) dimensions = [0] +// CHECK: (%in: f32, %init: f32) { +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %in, %init : f32 +// CHECK: linalg.yield %[[VAL_5]] : f32 +// CHECK: } +// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_4]][] : tensor +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1xf32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%extracted : f32) outs(%[[VAL_7]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK-DAG: %[[VAL_9:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_8]] in writable %[[VAL_9]] : (tensor<1xf32>, memref<1xf32, strided<[1]>>) -> () +// CHECK: return + + +// ----- + + +module { + tt.func public @minnumf(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<4096xf32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: f32, %arg15: f32): + %69 = arith.minnumf %arg14, %arg15 : f32 + tt.reduce.return %69 : f32 + }) {axis = 0 : i32} : (tensor<4096xf32>) -> f32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + +// CHECK-LABEL: func.func @minnumf +// CHECK: %[[CST:.*]] = arith.constant 0x7F800000 : f32 +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<4096xf32> +// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[VAL_0]] : tensor<4096xf32>) -> tensor<4096xf32> +// CHECK: %[[VAL_2:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_2]] : tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = linalg.reduce ins(%[[VAL_1]] : tensor<4096xf32>) outs(%[[VAL_3]] : tensor) dimensions = [0] +// CHECK: (%in: f32, %init: f32) { +// CHECK: %[[VAL_5:.*]] = arith.minnumf %in, %init : f32 +// CHECK: linalg.yield %[[VAL_5]] : f32 +// CHECK: } 
+// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_4]][] : tensor +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1xf32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%extracted : f32) outs(%[[VAL_7]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_8]] in writable %[[VAL_9]] : (tensor<1xf32>, memref<1xf32, strided<[1]>>) -> () +// CHECK: return + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir new file mode 100644 index 000000000..4a6cbfb3e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir @@ -0,0 +1,135 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s +module { + tt.func public @minmax_sgt(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi sgt, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + +// CHECK: func.func @minmax_sgt +// CHECK: %[[VAL_c0:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor +// CHECK: %[[VAL_10:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_9]] : tensor) dimensions = [0] {reduce_mode = "max_with_index"} +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_11:.*]] = arith.cmpi sgt, %in, %init : i32 +// CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_11]], %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_13:.*]] = tensor.extract %[[VAL_10]][] : tensor +// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_13]] : i32) outs(%[[VAL_14]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %reinterpret_cast = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %reinterpret_cast : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: return + + +// ----- + +module { + tt.func public @minmax_ugt(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi ugt, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + +// CHECK: func.func @minmax_ugt +// CHECK: %[[VAL_c0:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor +// CHECK: %[[VAL_10:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_9]] : tensor) dimensions = [0] {reduce_mode = "max_with_index"} +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_11:.*]] = arith.cmpi 
ugt, %in, %init : i32 +// CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_11]], %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_13:.*]] = tensor.extract %[[VAL_10]][] : tensor +// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_13]] : i32) outs(%[[VAL_14]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %reinterpret_cast = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %reinterpret_cast : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: return + +// ----- + +module { + tt.func public @minmax_slt(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi slt, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + +// CHECK: func.func @minmax_slt +// CHECK: %[[VAL_c0:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor +// CHECK: %[[VAL_10:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_9]] : tensor) dimensions = [0] {reduce_mode = "min_with_index"} +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %in, %init : i32 +// CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_11]], %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_13:.*]] = tensor.extract %[[VAL_10]][] : tensor +// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_13]] : i32) outs(%[[VAL_14]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %reinterpret_cast = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %reinterpret_cast : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: return + +// ----- + +module { + tt.func public @minmax_ult(%arg0: !tt.ptr) { + %cst_0 = arith.constant dense<0> : tensor<4096xi32> + %63 = "tt.reduce"(%cst_0) ({ + ^bb0(%arg14: i32, %arg15: i32): + %69 = arith.cmpi ult, %arg14, %arg15 : i32 + %70 = arith.select %69, %arg14, %arg15 : i32 + tt.reduce.return %70 : i32 + }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 + tt.store %arg0, %63 : !tt.ptr + tt.return + } +} + +// CHECK: func.func @minmax_ult +// CHECK: %[[VAL_c0:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> +// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor +// CHECK: %[[VAL_10:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_9]] : tensor) dimensions = [0] {reduce_mode = "min_with_index"} +// CHECK: (%in: i32, %init: i32) { +// CHECK: %[[VAL_11:.*]] = arith.cmpi ult, %in, %init : i32 +// CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_11]], %in, %init : i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } +// CHECK: %[[VAL_13:.*]] = tensor.extract %[[VAL_10]][] : tensor +// CHECK: %[[VAL_14:.*]] = 
tensor.empty() : tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_13]] : i32) outs(%[[VAL_14]] : tensor<1xi32>) -> tensor<1xi32> +// CHECK: %reinterpret_cast = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %reinterpret_cast : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: return diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_select.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_select.mlir new file mode 100644 index 000000000..8c974f2a0 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_select.mlir @@ -0,0 +1,65 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) -> () { + %cst = arith.constant dense<0.000000e+00> : tensor<512xf32> + %cst_0 = arith.constant dense<256> : tensor<512xi64> + %cst_1 = arith.constant dense<512> : tensor<512xi32> + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %1 = arith.cmpi slt, %0, %cst_1 : tensor<512xi32> + %2 = arith.extsi %0 : tensor<512xi32> to tensor<512xi64> + %3 = arith.cmpi slt, %2, %cst_0 : tensor<512xi64> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<512x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<512x!tt.ptr>, tensor<512xi32> + %6 = tt.load %5, %1, %cst evictionPolicy = evict_last : tensor<512x!tt.ptr> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<512x!tt.ptr> + %8 = tt.addptr %7, %0 : tensor<512x!tt.ptr>, tensor<512xi32> + %9 = tt.load %8, %1, %cst evictionPolicy = evict_last : tensor<512x!tt.ptr> + %10 = arith.select %3, %6, %9 : tensor<512xi1>, tensor<512xf32> + %11 = tt.splat %arg2 : !tt.ptr -> tensor<512x!tt.ptr> + %12 = tt.addptr %11, %0 : tensor<512x!tt.ptr>, tensor<512xi32> + tt.store %12, %10, %1 evictionPolicy = evict_last : tensor<512x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[ARG_2:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[ARG_3:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[ARG_4:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[c256_i64:%.+]] = arith.constant 256 : i64 +// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<512xi64> +// CHECK: %[[VAL_1:.*]] = linalg.fill ins([[c256_i64]] : i64) outs(%[[VAL_0]] : tensor<512xi64>) -> tensor<512xi64> +// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<512xi32> +// CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_2]] : tensor<512xi32>) { +// CHECK: ^bb0(%out: i32): +// CHECK: %[[VAL_10:.*]] = linalg.index 0 : index +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i32 +// CHECK: linalg.yield %[[VAL_11]] : i32 +// CHECK: } -> tensor<512xi32> +// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_3]] : tensor<512xi32>) outs(%[[VAL_0]] : tensor<512xi64>) { +// CHECK: ^bb0(%in: i32, %out: i64): +// CHECK: %[[VAL_10:.*]] = arith.extsi %in : i32 to i64 +// CHECK: linalg.yield %[[VAL_10]] : i64 +// CHECK: } -> tensor<512xi64> +// CHECK: %[[VAL_5:.*]] = tensor.empty() : 
tensor<512xi1> +// CHECK: %[[VAL_6:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_4]], %[[VAL_1]] : tensor<512xi64>, tensor<512xi64>) outs(%[[VAL_5]] : tensor<512xi1>) { +// CHECK: ^bb0(%in: i64, %in_3: i64, %out: i1): +// CHECK: %[[VAL_10:.*]] = arith.cmpi slt, %in, %in_3 : i64 +// CHECK: linalg.yield %[[VAL_10]] : i1 +// CHECK: } -> tensor<512xi1> +// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [0], sizes: [512], strides: [1] : memref to memref<512xf32, strided<[1]>> +// CHECK: %alloc = memref.alloc() : memref<512xf32> +// CHECK: memref.copy %reinterpret_cast, %alloc : memref<512xf32, strided<[1]>> to memref<512xf32> +// CHECK: %[[VAL_7:.*]] = bufferization.to_tensor %alloc restrict writable : memref<512xf32> +// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg3 to offset: [0], sizes: [512], strides: [1] : memref to memref<512xf32, strided<[1]>> +// CHECK: %alloc_1 = memref.alloc() : memref<512xf32> +// CHECK: memref.copy %reinterpret_cast_0, %alloc_1 : memref<512xf32, strided<[1]>> to memref<512xf32> +// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %alloc_1 restrict writable : memref<512xf32> +// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : tensor<512xi1>, tensor<512xf32>, tensor<512xf32>) outs(%[[VAL_7]] : tensor<512xf32>) { +// CHECK: ^bb0(%in: i1, %in_3: f32, %in_4: f32, %out: f32): +// CHECK: %[[VAL_10:.*]] = arith.select %in, %in_3, %in_4 : f32 +// CHECK: linalg.yield %[[VAL_10]] : f32 +// CHECK: } -> tensor<512xf32> +// CHECK: %reinterpret_cast_2 = memref.reinterpret_cast %arg4 to offset: [0], sizes: [512], strides: [1] : memref to memref<512xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %reinterpret_cast_2 : (tensor<512xf32>, memref<512xf32, strided<[1]>>) -> () +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_splat_float.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_splat_float.mlir new file mode 100644 index 000000000..a6befefb2 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_splat_float.mlir @@ -0,0 +1,23 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%fin : f32, + %bin : bf16, + %save0 : tensor<1024x!tt.ptr>, + %save1 : tensor<128x256x!tt.ptr>) -> () { + %0 = tt.splat %fin : f32 -> tensor<1024xf32> + %1 = tt.splat %bin : bf16 -> tensor<128x256xbf16> + tt.store %save0, %0 : tensor<1024x!tt.ptr> + tt.store %save1, %1 : tensor<128x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: memref<1024xf32>, %[[VAL_3:.*]]: memref<128x256xbf16>, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1024xf32> +// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : f32) outs(%[[VAL_7]] : tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_1]] : bf16) outs(%[[VAL_9]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: bufferization.materialize_in_destination 
%[[VAL_8]] in writable %[[VAL_2]] +// CHECK: bufferization.materialize_in_destination %[[VAL_10]] in writable %[[VAL_3]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir new file mode 100644 index 000000000..0de15dedb --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir @@ -0,0 +1,45 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @bcast_kernel_01(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %3 = tt.splat %1 : i32 -> tensor<32xi32> + %4 = arith.addi %3, %2 : tensor<32xi32> + %5 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32> + %6 = tt.splat %1 : i32 -> tensor<2048xi32> + %7 = arith.addi %6, %5 : tensor<2048xi32> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %10 = tt.load %9 : tensor<32x!tt.ptr> + %11 = tt.reshape %10 : tensor<32xf32> -> tensor<1x32xf32> + %12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32> + %13 = tt.reshape %12 : tensor<64x32xf32> -> tensor<2048xf32> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<2048x!tt.ptr> + %15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr>, tensor<2048xi32> + tt.store %15, %13 : tensor<2048x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @bcast_kernel_01 +// CHECK: %[[C2048_I64:.*]] = arith.constant dense<2048> : tensor<1xi64> +// CHECK: %[[CST:.*]] = arith.constant dense<[1, 32]> : tensor<2xi64> +// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32 +// CHECK: %[[VAR_0:.*]] = arith.muli [[PID_0:%.+]], %[[C32_I32]] : i32 +// CHECK: %[[VAR_1:.*]] = arith.index_cast %[[VAR_0]] : i32 to index +// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast [[ARG_0:%.+]] to offset: [%[[VAR_1]]], sizes: [32], strides: [1] : memref to memref<32xf32, strided<[1], offset: ?>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32> +// CHECK: memref.copy %[[REINTERPRET_CAST:.*]], %[[ALLOC]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32> +// CHECK: %[[VAR_2:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<32xf32> +// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[VAR_2]](%[[CST]]) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32> +// CHECK: %[[VAR_3:.*]] = tensor.empty() : tensor<64x32xf32> +// CHECK: %[[VAR_4:.*]] = tensor.collapse_shape %[[RESHAPE]] {{\[}}[0, 1]] : tensor<1x32xf32> into tensor<32xf32> +// CHECK: %[[VAR_5:.*]] = linalg.broadcast ins(%[[VAR_4]] : tensor<32xf32>) outs(%[[VAR_3]] : tensor<64x32xf32>) dimensions = [0] +// CHECK: %[[RESHAPE_0:.*]] = tensor.reshape %[[VAR_5]](%[[C2048_I64]]) : (tensor<64x32xf32>, tensor<1xi64>) -> tensor<2048xf32> + +// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast [[ARG_1:%.+]] to offset: [%[[VAR_1]]], sizes: [2048], strides: [1] : memref to memref<2048xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[RESHAPE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> () +// CHECK: return diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/cumsum.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/cumsum.mlir new file mode 100644 
index 000000000..afabbba78 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/cumsum.mlir @@ -0,0 +1,67 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// XFAIL: * +// @triton.jit +// def test_cumsum_op( +// input_ptr, output_ptr, n_columns +// ): +// row = tl.program_id(axis=0) +// row_start = row * n_columns +// columns = tl.arange(0, 4096) +// offsets = row_start + columns +// data = tl.load(input_ptr + offsets) +// result = tl.cumsum(data, axis=0) +// tl.store(output_ptr + offsets, result) +// +// ret = triton.compiler.compile( +// test_cumsum_op, +// signature=" *fp32,*i32,i32", +// print_triton_ir_only=True, +// ) +// print(ret.asm["ttir"]) + +module { + tt.func public @test_cumsum_op_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> + %3 = tt.splat %1 : i32 -> tensor<4096xi32> + %4 = arith.addi %3, %2 : tensor<4096xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> + %7 = tt.load %6 : tensor<4096x!tt.ptr> + %8 = "tt.scan"(%7) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg3: f32, %arg4: f32): + %12 = arith.addf %arg3, %arg4 : f32 + tt.scan.return %12 : f32 + }) : (tensor<4096xf32>) -> tensor<4096xf32> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4096x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> + %11 = arith.fptosi %8 : tensor<4096xf32> to tensor<4096xi32> + tt.store %10, %11 : tensor<4096x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @test_cumsum_op_012 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: [[PARAM_2_:%.+]]: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[FALSE:%.+]] = arith.constant false +// CHECK-DAG: [[C0_i32:%.+]] = arith.constant 0 : i32 +// CHECK: [[VAR_0_:%.+]] = arith.muli %arg8, [[PARAM_2_]] : i32 +// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [4096], strides: [1] : memref to memref<4096xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = call @triton_cumsum([[VAR_2_]], [[C0_i32]], %false) : (tensor<4096xf32>, i32, i1) -> tensor<4096xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [4096], strides: [1] : memref to memref<4096xi32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096xi32> +// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]] : tensor<4096xf32>) outs([[VAR_6_]] : tensor<4096xi32>) { +// CHECK: ^bb0([[in_:.+]]: f32, [[out_:.+]]: i32): +// CHECK: [[VAR_8_:%.+]] = 
arith.fptosi [[in_]] : f32 to i32 +// CHECK: linalg.yield [[VAR_8_]] : i32 +// CHECK: } -> tensor<4096xi32> +// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_0_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/debug.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/debug.mlir new file mode 100644 index 000000000..75bc60028 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/debug.mlir @@ -0,0 +1,18 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + // CHECK: func.func private @triton_print_0(i32) attributes {hex = false, prefix = " pid =: "} + // CHECK-NEXT: func.func private @triton_print_1(tensor<1024xf32>) attributes {hex = true, prefix = " Val =: "} + // CHECK: func.func @test_print + tt.func public @test_print(%arg0: i32, %arg1: !tt.ptr) { + %0 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %2 = tt.addptr %0, %1 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %3 = tt.load %2 : tensor<1024x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " pid =: " {hex = false, isSigned = array} : %arg0 : i32 + %4 = arith.addf %3, %3 : tensor<1024xf32> + // CHECK: call @triton_print_1 + tt.print " Val =: " {hex = true, isSigned = array} : %3 : tensor<1024xf32> + tt.return + } +} \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/dot.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/dot.mlir new file mode 100644 index 000000000..d65a9dc1a --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/dot.mlir @@ -0,0 +1,73 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr + ) + { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c64 = arith.constant 128 : i32 + %1 = tt.splat %c64 : i32 -> tensor<128xi32> + %2 = arith.muli %0, %1 : tensor<128xi32> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x64xi32> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %7 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<128x64xi32> + %8 = arith.addi %4, %7 : tensor<128x64xi32> + %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + %12 = tt.broadcast %11 : tensor<256x1xi32> -> tensor<256x64xi32> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %c256 = arith.constant 256 : i32 + %14 = tt.splat %c256 : i32 -> tensor<64xi32> + %15 = arith.muli %13, %14 : tensor<64xi32> + %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %17 = tt.broadcast %16 : tensor<1x64xi32> -> tensor<256x64xi32> + %18 = arith.addi %12, %17 : tensor<256x64xi32> + %20 = tt.splat %c256 : i32 -> tensor<128xi32> + %21 = arith.muli %0, %20 : tensor<128xi32> + %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %23 = tt.broadcast %22 : tensor<128x1xi32> -> tensor<128x256xi32> + %24 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %25 = tt.broadcast %24 {axis = 0 : i32} : tensor<1x256xi32> -> tensor<128x256xi32> + %26 = 
arith.addi %23, %25 : tensor<128x256xi32> + %30 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64x!tt.ptr> + %40 = tt.splat %arg1 : !tt.ptr -> tensor<256x64x!tt.ptr> + %41 = tt.addptr %40, %18 : tensor<256x64x!tt.ptr>, tensor<256x64xi32> + %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x64x!tt.ptr> + %43 = tt.trans %42 {order = array} : tensor<256x64xbf16> -> tensor<64x256xbf16> + %50 = tt.splat %arg2 : !tt.ptr -> tensor<128x256x!tt.ptr> + %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %52 = tt.load %51 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x256x!tt.ptr> + %60 = tt.dot %32, %43, %52 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> + tt.store %51, %60 : tensor<128x256x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_2_:%.+]]: memref {tt.tensor_kind = 2 : i32}, +// CHECK-SAME: %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "mix"} { +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: [128, 1] : memref to memref<128x64xbf16, strided<[128, 1]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[128, 1]>> to memref<128x64xbf16> +// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [256, 64], strides: [1, 256] : memref to memref<256x64xbf16, strided<[1, 256]>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x64xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<256x64xbf16, strided<[1, 256]>> to memref<256x64xbf16> +// CHECK-DAG: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x64xbf16> +// CHECK-DAG: annotation.mark [[VAR_1_]] {MayImplicitTransposeWithLastAxis} : tensor<256x64xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<64x256xbf16> +// CHECK-DAG: [[VAR_transposed_:%.+]] = linalg.transpose ins([[VAR_1_]] : tensor<256x64xbf16>) outs([[VAR_2_]] : tensor<64x256xbf16>) permutation = [1, 0] +// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: [256, 1] : memref to memref<128x256xbf16, strided<[256, 1]>> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<128x256xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_2_]] : memref<128x256xbf16, strided<[256, 1]>> to memref<128x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<128x256xbf16> +// CHECK: [[VAR_4_:%.+]] = linalg.matmul {input_precison = "ieee"} ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: bufferization.materialize_in_destination 
[[VAR_4_]] in writable [[VAR_reinterpret_cast_2_]] : (tensor<128x256xbf16>, memref<128x256xbf16, strided<[256, 1]>>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/get_num_programs.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/get_num_programs.mlir new file mode 100644 index 000000000..d56d00ebf --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/get_num_programs.mlir @@ -0,0 +1,41 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @num_programs(%arg0: !tt.ptr) { + %0 = tt.get_num_programs x : i32 + %1 = tt.get_num_programs y : i32 + %2 = tt.get_num_programs z : i32 + %3 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> + %4 = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> + %5 = tt.make_range {end = 3 : i32, start = 2 : i32} : tensor<1xi32> + %6 = tt.splat %arg0 : !tt.ptr -> tensor<1x!tt.ptr> + %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr>, tensor<1xi32> + %8 = tt.splat %0 : i32 -> tensor<1xi32> + tt.store %7, %8 : tensor<1x!tt.ptr> + %9 = tt.addptr %6, %4 : tensor<1x!tt.ptr>, tensor<1xi32> + %10 = tt.splat %1 : i32 -> tensor<1xi32> + tt.store %9, %10 : tensor<1x!tt.ptr> + %11 = tt.addptr %6, %5 : tensor<1x!tt.ptr>, tensor<1xi32> + %12 = tt.splat %2 : i32 -> tensor<1xi32> + tt.store %11, %12 : tensor<1x!tt.ptr> + tt.return + } +} + + +// CHECK-LABEL: func.func @num_programs +// CHECK-SAME: (%arg0: memref, %arg1: memref, %arg2: memref {tt.tensor_kind = 1 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { + +// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [0], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1]>> +// CHECK: %0 = tensor.empty() : tensor<1xi32> +// CHECK: %1 = linalg.fill ins(%arg3 : i32) outs(%0 : tensor<1xi32>) -> tensor<1xi32> +// CHECK: bufferization.materialize_in_destination %1 in writable %reinterpret_cast : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> () +// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg2 to offset: [1], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: 1>> +// CHECK: %2 = linalg.fill ins(%arg4 : i32) outs(%0 : tensor<1xi32>) -> tensor<1xi32> +// CHECK: bufferization.materialize_in_destination %2 in writable %reinterpret_cast_0 : (tensor<1xi32>, memref<1xi32, strided<[1], offset: 1>>) -> () +// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [2], sizes: [1], strides: [1] : memref to memref<1xi32, strided<[1], offset: 2>> +// CHECK: %3 = linalg.fill ins(%arg5 : i32) outs(%0 : tensor<1xi32>) -> tensor<1xi32> +// CHECK: bufferization.materialize_in_destination %3 in writable %reinterpret_cast_1 : (tensor<1xi32>, memref<1xi32, strided<[1], offset: 2>>) -> ( +// CHECK: return +// CHECK: } + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/interleave.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/interleave.mlir new file mode 100644 index 000000000..3f8f0634a --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/interleave.mlir @@ -0,0 +1,62 @@ +// RUN: triton-adapter-opt --triton-to-linalg -split-input-file %s | FileCheck %s + +module { + tt.func public @load_deinterleave(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst_1 = arith.constant dense<1> : tensor<32xi32> + %cst_2 = arith.constant dense<2> : tensor<32xi32> + %0 = tt.make_range {end = 32 
: i32, start = 0 : i32} : tensor<32xi32> + // Note: the multiplication by 2 indicates that the last-dimension stride is 2 + %1 = arith.muli %0, %cst_2 : tensor<32xi32> + %2 = arith.addi %1, %cst_1 : tensor<32xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %4 = tt.addptr %3, %1 : tensor<32x!tt.ptr>, tensor<32xi32> + // even index + // CHECK: %[[VAL_0:.*]] = bufferization.to_tensor + // CHECK: tensor.extract_slice %[[VAL_0]][0] [32] [2] : tensor<64xf16> to tensor<32xf16> + %5 = tt.load %4 : tensor<32x!tt.ptr> + %6 = tt.addptr %3, %2 : tensor<32x!tt.ptr>, tensor<32xi32> + // odd index + // CHECK: %[[VAL_1:.*]] = bufferization.to_tensor + // CHECK: tensor.extract_slice %[[VAL_1]][1] [32] [2] : tensor<64xf16> to tensor<32xf16> + %7 = tt.load %6 : tensor<32x!tt.ptr> + %8 = tt.make_range {end = 64 : i32, start = 32 : i32} : tensor<32xi32> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %10 = tt.addptr %9, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %10, %5 : tensor<32x!tt.ptr> + %11 = tt.addptr %9, %8 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %11, %7 : tensor<32x!tt.ptr> + tt.return + } +} + + +// ----- + +module { + tt.func public @store_interleave(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %cst_1 = arith.constant dense<1> : tensor<32xi32> + %cst_2 = arith.constant dense<2> : tensor<32xi32> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.make_range {end = 64 : i32, start = 32 : i32} : tensor<32xi32> + %5 = tt.addptr %1, %4 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + // Note: the multiplication by 2 indicates that the last-dimension stride is 2 + %7 = arith.muli %0, %cst_2 : tensor<32xi32> + %8 = arith.addi %7, %cst_1 : tensor<32xi32> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %10 = tt.addptr %9, %7 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %10, %3 : tensor<32x!tt.ptr> + %11 = tt.addptr %9, %8 : tensor<32x!tt.ptr>, tensor<32xi32> + // CHECK: %[[LOAD_FIRST:.*]] = bufferization.to_tensor + // CHECK: %[[LOAD_SECOND:.*]] = bufferization.to_tensor + // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<64xf16> + // CHECK: %[[INSERT_FIRST:.*]] = tensor.insert_slice %[[LOAD_FIRST]] into %[[EMPTY]][0] [32] [2] : tensor<32xf16> into tensor<64xf16> + // CHECK: %[[INSERT_SECOND:.*]] = tensor.insert_slice %[[LOAD_SECOND]] into %[[INSERT_FIRST]][1] [32] [2] : tensor<32xf16> into tensor<64xf16> + // CHECK: bufferization.materialize_in_destination %[[INSERT_SECOND]] + tt.store %11, %6 : tensor<32x!tt.ptr> + tt.return + } +} \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/join.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/join.mlir new file mode 100644 index 000000000..d9d88196c --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/join.mlir @@ -0,0 +1,48 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @fn_triton_join(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<2> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<2> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %cst_2 = arith.constant dense<8> : tensor<8x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.expand_dims %0 {axis = 1 : 
i32} : tensor<8xi32> -> tensor<8x1xi32> + %2 = arith.muli %1, %cst_2 : tensor<8x1xi32> + %3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %4 = tt.broadcast %2 : tensor<8x1xi32> -> tensor<8x8xi32> + %5 = tt.broadcast %3 : tensor<1x8xi32> -> tensor<8x8xi32> + %6 = arith.addi %4, %5 : tensor<8x8xi32> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x!tt.ptr> + %8 = tt.addptr %7, %6 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> + %9 = tt.load %8 : tensor<8x8x!tt.ptr> + %10 = tt.splat %arg2 : !tt.ptr -> tensor<8x8x!tt.ptr> + %11 = tt.addptr %10, %6 : tensor<8x8x!tt.ptr>, tensor<8x8xi32> + %12 = tt.load %11 : tensor<8x8x!tt.ptr> + %13 = tt.join %9, %12 : tensor<8x8xf32> -> tensor<8x8x2xf32> + %14 = tt.expand_dims %1 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %15 = arith.muli %14, %cst_1 : tensor<8x1x1xi32> + %16 = arith.muli %15, %cst_0 : tensor<8x1x1xi32> + %17 = tt.expand_dims %3 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %18 = arith.muli %17, %cst : tensor<1x8x1xi32> + %19 = tt.broadcast %16 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %20 = tt.broadcast %18 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %21 = arith.addi %19, %20 : tensor<8x8x1xi32> + %22 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %23 = tt.expand_dims %22 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<1x2xi32> -> tensor<1x1x2xi32> + %25 = tt.broadcast %21 : tensor<8x8x1xi32> -> tensor<8x8x2xi32> + %26 = tt.broadcast %24 : tensor<1x1x2xi32> -> tensor<8x8x2xi32> + %27 = arith.addi %25, %26 : tensor<8x8x2xi32> + %28 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x2x!tt.ptr> + %29 = tt.addptr %28, %27 : tensor<8x8x2x!tt.ptr>, tensor<8x8x2xi32> + tt.store %29, %13 : tensor<8x8x2x!tt.ptr> + tt.return + } +} +//CHECK-LABEL: @fn_triton_join +//CHECK-NOT: tt.join +//CHECK: %[[IN0:.*]] = bufferization.to_tensor %[[ADDR0:.*]] restrict writable : memref<8x8xf32> +//CHECK: %[[IN1:.*]] = bufferization.to_tensor %[[ADDR1:.*]] restrict writable : memref<8x8xf32> +//CHECK: %[[ZERO:.*]] = tensor.empty() : tensor<8x8x2xf32> +//CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[IN0]] into %[[ZERO]][0, 0, 0] [8, 8, 1] [1, 1, 2] : tensor<8x8xf32> into tensor<8x8x2xf32> +//CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[IN1]] into %[[INSERT0]][0, 0, 1] [8, 8, 1] [1, 1, 2] : tensor<8x8xf32> into tensor<8x8x2xf32> \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir new file mode 100644 index 000000000..6edd4ba52 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir @@ -0,0 +1,65 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr 
%10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr> + %13 = arith.addf %9, %12 : tensor<1024xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @add_kernel_01234( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_2_:%.+]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: [[PARAM_3_:%.+]]: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_1024_1_:%.+]] = arith.constant 1024 : i32 +// CHECK: [[VAR_0_:%.+]] = arith.muli %arg9, [[CST_1024_1_]] : i32 +// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_1_]], [[CST_1024_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_5_0_:%.+]] = arith.maxsi [[VAR_1_]], [[VAR_4_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_3_]], [[VAR_5_0_]] : index +// CHECK: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1]{{.*}} : memref<1024xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32> to memref> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_0 : memref> to memref> +// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1024xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<1024xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] [[[VAR_6_]]] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] [[[VAR_6_]]] [1] : memref<1024xf32> to memref> +// CHECK: memref.copy [[VAR_subview_3_]], [[VAR_subview_4_]] : memref> to memref> +// CHECK: [[VAR_14_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<1024xf32> +// CHECK: [[VAR_15_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_14_]] : tensor<1024xf32>, tensor<1024xf32>) outs([[VAR_7_]] : tensor<1024xf32>) { +// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): +// CHECK: [[VAR_22_:%.+]] = arith.addf [[in_1]], [[in_2]] : f32 +// CHECK: linalg.yield [[VAR_22_]] : f32 +// CHECK: } -> tensor<1024xf32> 
+// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref to memref<1024xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_15_]][0] [[[VAR_6_]]] [1] : tensor<1024xf32> to tensor +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] [[[VAR_6_]]] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_6_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir new file mode 100644 index 000000000..b8e776463 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir @@ -0,0 +1,105 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @softmax_kernel_012345(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32) { + %cst = arith.constant 0xFF800000 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %4 = tt.splat %2 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.splat %arg4 : i32 -> tensor<128xi32> + %7 = arith.cmpi slt, %3, %6 : tensor<128xi32> + %8 = tt.splat %cst : f32 -> tensor<128xf32> + %9 = tt.load %5, %7, %8 : tensor<128x!tt.ptr> + %10 = "tt.reduce"(%9) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.cmpf ogt, %arg5, %arg6 : f32 + %22 = arith.select %21, %arg5, %arg6 : f32 + tt.reduce.return %22 : f32 + }) {axis = 0 : i32} : (tensor<128xf32>) -> f32 + %11 = tt.splat %10 : f32 -> tensor<128xf32> + %12 = arith.subf %9, %11 : tensor<128xf32> + %13 = math.exp %12 : tensor<128xf32> + %14 = "tt.reduce"(%13) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %21 : f32 + }) {axis = 0 : i32} : (tensor<128xf32>) -> f32 + %15 = tt.splat %14 : f32 -> tensor<128xf32> + %16 = arith.divf %13, %15 : tensor<128xf32> + %17 = arith.muli %0, %arg3 : i32 + %18 = tt.addptr %arg0, %17 : !tt.ptr, i32 + %19 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr> + %20 = tt.addptr %19, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %20, %16, %7 : tensor<128x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @softmax_kernel_012345 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 0 : i32}, +// CHECK-SAME: [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: [[CST:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_0_index_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli %arg10, [[PARAM_2_]] : i32 +// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = 
memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [128], strides: [1]{{.*}} : memref to memref<128xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_3_0_:%.+]] = arith.maxsi [[VAR_2_]], [[CST_0_index_]] : index +// CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_3_0_]], [[CST_128_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi slt, [[VAR_3_]], [[CST_128_]] : index +// CHECK: scf.if [[VAR_4_]] { +// CHECK: linalg.fill ins([[CST_0_]] : f32) outs([[RES_]] : memref<128xf32>) +// CHECK: } +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_3_]]{{.}} [1]{{.*}} : memref<128xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32> to memref> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_1 : memref> to memref> +// CHECK-DAG: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tensor.empty() : tensor +// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_5_]] : tensor<128xf32>) outs([[VAR_6_]] : tensor) dimensions = [0] +// CHECK: ([[in_1:%.+]]: f32, [[init_1:%.+]]: f32) { +// CHECK: [[VAR_18_:%.+]] = arith.cmpf ogt, [[in_1]], [[init_1]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.select [[VAR_18_]], [[in_1]], [[init_1]] : f32 +// CHECK: linalg.yield [[VAR_19_]] : f32 +// CHECK: } +// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor +// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor<128xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[VAR_extracted_]] : f32) outs([[VAR_7_]] : tensor<128xf32>) -> tensor<128xf32> +// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_5_]], [[VAR_8_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_5_]] : tensor<128xf32>) { +// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): +// CHECK: [[VAR_19_1_:%.+]] = arith.subf [[in_1]], [[in_2]] : f32 +// CHECK: linalg.yield [[VAR_19_1_]] : f32 +// CHECK: } -> tensor<128xf32> +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]] : tensor<128xf32>) outs([[VAR_9_]] : tensor<128xf32>) { +// CHECK: ^bb0([[in_1:%.+]]: f32, [[out_1:%.+]]: f32): +// CHECK: [[VAR_19_2_:%.+]] = math.exp [[in_1]] : f32 +// CHECK: linalg.yield [[VAR_19_2_]] : f32 +// CHECK: } -> tensor<128xf32> +// CHECK: [[VAR_11_:%.+]] = bufferization.alloc_tensor() : tensor +// CHECK: [[VAR_inserted_2_:%.+]] = linalg.fill ins([[CST]] : f32) outs([[VAR_11_]] : tensor) -> tensor +// CHECK: [[VAR_reduced_3_:%.+]] = linalg.reduce ins([[VAR_10_]] : tensor<128xf32>) outs([[VAR_inserted_2_]] : tensor) dimensions = [0] +// CHECK: ([[in_1:%.+]]: f32, [[init_1:%.+]]: f32) { +// CHECK: [[VAR_19_3_:%.+]] = arith.addf [[in_1]], [[init_1]] : f32 +// CHECK: linalg.yield [[VAR_19_3_]] : f32 +// CHECK: } +// CHECK-DAG: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_reduced_3_]][] : tensor +// CHECK: [[VAR_13_:%.+]] = linalg.fill ins([[VAR_extracted_4_]] : f32) outs([[VAR_7_]] : tensor<128xf32>) -> tensor<128xf32> +// CHECK: [[VAR_14_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]], [[VAR_13_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_10_]] : tensor<128xf32>) { +// 
CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out_1:%.+]]: f32): +// CHECK: [[VAR_19_4_:%.+]] = arith.divf [[in_1]], [[in_2]] : f32 +// CHECK: linalg.yield [[VAR_19_4_]] : f32 +// CHECK: } -> tensor<128xf32> +// CHECK: [[VAR_15_:%.+]] = arith.muli %arg10, [[PARAM_3_]] : i32 +// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_16_]]{{.}}, sizes: [128], strides: [1]{{.*}} : memref to memref<128xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_14_]][0] [[[VAR_3_]]] [1] : tensor<128xf32> to tensor +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_3_]]{{.}} [1]{{.*}} : memref<128xf32, strided<[1], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_6_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir new file mode 100644 index 000000000..9a4ac6365 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir @@ -0,0 +1,208 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @matmul_kernel_0123456789101112131415(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %4, %c8_i32 : i32 + %8 = arith.divsi %0, %7 : i32 + %9 = arith.muli %8, %c8_i32 : i32 + %10 = arith.subi %2, %9 : i32 + %11 = arith.cmpi slt, %10, %c8_i32 : i32 + %12 = arith.select %11, %10, %c8_i32 : i32 + %13 = arith.remsi %0, %12 : i32 + %14 = arith.addi %9, %13 : i32 + %15 = arith.remsi %0, %7 : i32 + %16 = arith.divsi %15, %12 : i32 + %17 = arith.muli %14, %c128_i32 : i32 + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %19 = tt.splat %17 : i32 -> tensor<128xi32> + %20 = arith.addi %19, %18 : tensor<128xi32> + %21 = arith.muli %16, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %23 = tt.splat %21 : i32 -> tensor<256xi32> + %24 = arith.addi %23, %22 : tensor<256xi32> + %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %26 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %27 = tt.splat %arg6 : i32 -> tensor<128x1xi32> + %28 = arith.muli %26, %27 : tensor<128x1xi32> + %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32> + %31 = arith.muli %29, %30 : tensor<1x64xi32> + %32 = tt.broadcast %28 : tensor<128x1xi32> -> 
tensor<128x64xi32> + %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<128x64xi32> + %34 = arith.addi %32, %33 : tensor<128x64xi32> + %35 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %36 = tt.addptr %35, %34 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32> + %39 = arith.muli %37, %38 : tensor<64x1xi32> + %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %41 = tt.splat %arg9 : i32 -> tensor<1x256xi32> + %42 = arith.muli %40, %41 : tensor<1x256xi32> + %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x256xi32> + %44 = tt.broadcast %42 : tensor<1x256xi32> -> tensor<64x256xi32> + %45 = arith.addi %43, %44 : tensor<64x256xi32> + %46 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> + %47 = tt.addptr %46, %45 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %48 = tt.splat %cst : f32 -> tensor<128x256xf32> + %49 = arith.muli %arg7, %c64_i32 : i32 + %50 = tt.splat %49 : i32 -> tensor<128x64xi32> + %51 = arith.muli %arg8, %c64_i32 : i32 + %52 = tt.splat %51 : i32 -> tensor<64x256xi32> + %53:3 = scf.for %arg12 = %c0_i32 to %6 step %c1_i32 iter_args(%arg13 = %48, %arg14 = %36, %arg15 = %47) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr>) : i32 { + %71 = tt.load %arg14 : tensor<128x64x!tt.ptr> + %72 = tt.load %arg15 : tensor<64x256x!tt.ptr> + %73 = tt.dot %71, %72, %48 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> + %74 = arith.addf %arg13, %73 : tensor<128x256xf32> + %75 = tt.addptr %arg14, %50 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %76 = tt.addptr %arg15, %52 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + scf.yield %74, %75, %76 : tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr> + } + %54 = arith.truncf %53#0 : tensor<128x256xf32> to tensor<128x256xbf16> + %55 = tt.splat %arg10 : i32 -> tensor<128x1xi32> + %56 = arith.muli %55, %26 : tensor<128x1xi32> + %57 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + %58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %59 = tt.splat %arg11 : i32 -> tensor<1x256xi32> + %60 = arith.muli %59, %40 : tensor<1x256xi32> + %61 = tt.broadcast %58 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %62 = tt.broadcast %60 : tensor<1x256xi32> -> tensor<128x256xi32> + %63 = tt.addptr %61, %62 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %64 = tt.splat %arg3 : i32 -> tensor<128x1xi32> + %65 = arith.cmpi slt, %26, %64 : tensor<128x1xi32> + %66 = tt.splat %arg4 : i32 -> tensor<1x256xi32> + %67 = arith.cmpi slt, %40, %66 : tensor<1x256xi32> + %68 = tt.broadcast %65 : tensor<128x1xi1> -> tensor<128x256xi1> + %69 = tt.broadcast %67 : tensor<1x256xi1> -> tensor<128x256xi1> + %70 = arith.andi %68, %69 : tensor<128x256xi1> + tt.store %63, %54, %70 : tensor<128x256x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @matmul_kernel_0123456789101112131415 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, 
[[PARAM_14_:%.+]]: i32, [[PARAM_15_:%.+]]: i32, [[PARAM_16_:%.+]]: i32, [[PARAM_17_:%.+]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "mix"} { +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK-DAG: [[CST_128_1_:%.+]] = arith.constant 128 : i32 +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_127_:%.+]] = arith.constant 127 : i32 +// CHECK-DAG: [[CST_255_:%.+]] = arith.constant 255 : i32 +// CHECK-DAG: [[CST_63_:%.+]] = arith.constant 63 : i32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<128x256xf32>) -> tensor<128x256xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[PARAM_3_]], [[CST_127_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = arith.divsi [[VAR_2_]], [[CST_128_1_]] : i32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[PARAM_4_]], [[CST_255_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_:%.+]] = arith.divsi [[VAR_4_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[PARAM_5_]], [[CST_63_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.divsi [[VAR_6_]], [[CST_64_]] : i32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[VAR_5_]], [[CST_8_]] : i32 +// CHECK: [[VAR_9_:%.+]] = arith.divsi [[PARAM_15_]], [[VAR_8_]] : i32 +// CHECK: [[VAR_10_:%.+]] = arith.muli [[VAR_9_]], [[CST_8_]] : i32 +// CHECK: [[VAR_11_:%.+]] = arith.subi [[VAR_3_]], [[VAR_10_]] : i32 +// CHECK: [[VAR_12_:%.+]] = arith.cmpi slt, [[VAR_11_]], [[CST_8_]] : i32 +// CHECK: [[VAR_12_1:%.+]] = arith.select [[VAR_12_]], [[VAR_11_]], [[CST_8_]] : i32 +// CHECK: [[VAR_13_:%.+]] = arith.remsi [[PARAM_15_]], [[VAR_12_1]] : i32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_10_]], [[VAR_13_]] : i32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.remsi [[PARAM_15_]], [[VAR_8_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.divsi [[VAR_15_]], [[VAR_12_1]] : i32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_14_]], [[CST_128_1_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_18_:%.+]] = arith.muli [[VAR_16_]], [[CST_256_1_]] : i32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index +// CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[VAR_19_]], [[VAR_20_]] : index +// CHECK-DAG: [[VAR_22_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_8_]] : i32 to index +// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_18_]] : i32 to index +// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.muli [[VAR_24_]], [[VAR_25_]] : index +// CHECK-DAG: 
[[VAR_27_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_]] : i32 +// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_]] : i32 +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_21_]]{{.}}, sizes: [128, 64], strides: {{.}}[[VAR_20_]], [[VAR_22_]]{{.}} : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: [64, 256], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}} : memref to memref<64x256xbf16, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_29_:%.+]]:7 = scf.for [[VAR_arg18_:%.+]] = [[CST_0_1_]] to [[VAR_7_]] step [[CST_1_]] iter_args([[VAR_arg19_:%.+]] = [[VAR_1_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg21_:%.+]] = [[VAR_reinterpret_cast_]]_0, [[VAR_arg22_:%.+]] = [[VAR_21_]], [[VAR_arg23_:%.+]] = [[CST_0_]], [[VAR_arg24_:%.+]] = [[VAR_26_]], [[VAR_arg25_:%.+]] = [[CST_0_]]) -> (tensor<128x256xf32>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<64x256xbf16, strided<[?, ?], offset: ?>>, index, index, index, index) : i32 { +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy [[VAR_arg20_]], [[RES_]] : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> +// CHECK-DAG: [[VAR_51_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16> +// CHECK: memref.copy [[VAR_arg21_]], [[RES_1_]] : memref<64x256xbf16, strided<[?, ?], offset: ?>> to memref<64x256xbf16> +// CHECK-DAG: [[VAR_52_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16> +// CHECK: [[VAR_55_:%.+]] = linalg.matmul {input_precison = "tf32"} ins([[VAR_51_]], [[VAR_52_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_1_]] : tensor<128x256xf32>) -> tensor<128x256xf32> +// CHECK: [[VAR_57_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg19_]], [[VAR_55_]] : tensor<128x256xf32>, tensor<128x256xf32>) outs([[VAR_arg19_]] : tensor<128x256xf32>) { +// CHECK: ^bb0([[in_:.+]]: f32, [[in_1:.+]]: f32, [[out_:.+]]: f32): +// CHECK: [[VAR_64_1_:%.+]] = arith.addf [[in_]], [[in_1]] : f32 +// CHECK: linalg.yield [[VAR_64_1_]] : f32 +// CHECK: } -> tensor<128x256xf32> +// CHECK: [[VAR_58_:%.+]] = arith.index_cast [[VAR_27_]] : i32 to index +// CHECK: [[VAR_59_:%.+]] = arith.addi [[VAR_arg22_]], [[VAR_58_]] : index +// CHECK: [[VAR_60_:%.+]] = arith.addi [[VAR_59_]], [[VAR_arg23_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_60_]]{{.}}, sizes: [128, 64], strides: {{.}}[[VAR_20_]], [[VAR_22_]]{{.}} : memref to memref<128x64xbf16, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_61_:%.+]] = arith.index_cast [[VAR_28_]] : i32 to index +// CHECK: [[VAR_62_:%.+]] = arith.addi [[VAR_arg24_]], [[VAR_61_]] : index +// CHECK: [[VAR_63_:%.+]] = arith.addi [[VAR_62_]], [[VAR_arg25_]] : index +// CHECK: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_63_]]{{.}}, sizes: [64, 256], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}} : memref to memref<64x256xbf16, strided<[?, ?], offset: ?>> +// CHECK: scf.yield [[VAR_57_]], [[VAR_reinterpret_cast_3_]], [[VAR_reinterpret_cast_4_]], [[VAR_60_]], [[CST_0_]], [[VAR_63_]], [[CST_0_]] : tensor<128x256xf32>, memref<128x64xbf16, strided<[?, ?], 
offset: ?>>, memref<64x256xbf16, strided<[?, ?], offset: ?>>, index, index, index, index +// CHECK: } +// CHECK: [[VAR_30_:%.+]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: [[VAR_31_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_29_]]#0 : tensor<128x256xf32>) outs([[VAR_30_]] : tensor<128x256xbf16>) { +// CHECK: ^bb0([[in_:.+]]: f32, [[out_:.+]]: bf16): +// CHECK: [[VAR_51_1_:%.+]] = arith.truncf [[in_]] : f32 to bf16 +// CHECK: linalg.yield [[VAR_51_1_]] : bf16 +// CHECK: } -> tensor<128x256xbf16> +// CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index +// CHECK-DAG: [[VAR_33_:%.+]] = arith.muli [[VAR_19_]], [[VAR_32_]] : index +// CHECK-DAG: [[VAR_34_:%.+]] = arith.index_cast [[PARAM_11_]] : i32 to index +// CHECK-DAG: [[VAR_35_:%.+]] = arith.muli [[VAR_24_]], [[VAR_34_]] : index +// CHECK: [[VAR_38_:%.+]] = arith.addi [[VAR_33_]], [[VAR_35_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_38_]]{{.}}, sizes: [128, 256], strides: {{.}}[[VAR_32_]], [[VAR_34_]]{{.}} : memref to memref<128x256xbf16, strided<[?, ?], offset: ?>> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addi [[VAR_19_]], [[CST_128_]] : index +// CHECK-DAG: [[VAR_41_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_42_0_:%.+]] = arith.maxsi [[VAR_19_]], [[VAR_41_]] : index +// CHECK: [[VAR_42_:%.+]] = arith.minsi [[VAR_40_]], [[VAR_42_0_]] : index +// CHECK-DAG: [[VAR_43_:%.+]] = arith.subi [[VAR_42_]], [[VAR_19_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_45_:%.+]] = arith.addi [[VAR_24_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_46_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_47_0_:%.+]] = arith.maxsi [[VAR_24_]], [[VAR_46_]] : index +// CHECK: [[VAR_47_:%.+]] = arith.minsi [[VAR_45_]], [[VAR_47_0_]] : index +// CHECK-DAG: [[VAR_48_:%.+]] = arith.subi [[VAR_47_]], [[VAR_24_]] : index +// CHECK-DAG: [[VAR_49_:%.+]] = arith.minsi [[VAR_43_]], [[CST_128_]] : index +// CHECK: [[VAR_50_:%.+]] = arith.minsi [[VAR_48_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_31_]][0, 0] {{.}}[[VAR_49_]], [[VAR_50_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_49_]], [[VAR_50_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[?, ?], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir new file mode 100644 index 000000000..13019595f --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir @@ -0,0 +1,164 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @_layer_norm_bwd_dwdb_0123456(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: i32, %arg5: i32) { + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %3 = tt.splat %1 : i32 -> 
tensor<256xi32> + %4 = arith.addi %3, %2 : tensor<256xi32> + %5 = tt.splat %cst : f32 -> tensor<256x256xf32> + %6 = tt.splat %arg4 : i32 -> tensor<256x1xi32> + %7 = tt.expand_dims %4 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %8 = tt.splat %arg5 : i32 -> tensor<1x256xi32> + %9 = arith.cmpi slt, %7, %8 : tensor<1x256xi32> + %10 = tt.broadcast %9 : tensor<1x256xi1> -> tensor<256x256xi1> + %11 = tt.splat %arg5 : i32 -> tensor<256x1xi32> + %12 = tt.broadcast %7 : tensor<1x256xi32> -> tensor<256x256xi32> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<256x256x!tt.ptr> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<256x256x!tt.ptr> + %15:2 = scf.for %arg6 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg7 = %5, %arg8 = %5) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { + %24 = tt.splat %arg6 : i32 -> tensor<256xi32> + %25 = arith.addi %24, %2 : tensor<256xi32> + %26 = tt.expand_dims %25 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + %27 = arith.cmpi slt, %26, %6 : tensor<256x1xi32> + %28 = tt.broadcast %27 : tensor<256x1xi1> -> tensor<256x256xi1> + %29 = arith.andi %28, %10 : tensor<256x256xi1> + %30 = arith.muli %26, %11 : tensor<256x1xi32> + %31 = tt.broadcast %30 : tensor<256x1xi32> -> tensor<256x256xi32> + %32 = arith.addi %31, %12 : tensor<256x256xi32> + %33 = tt.addptr %13, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %34 = tt.load %33, %29, %5 : tensor<256x256x!tt.ptr> + %35 = arith.addf %arg7, %34 : tensor<256x256xf32> + %36 = tt.addptr %14, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %37 = tt.load %36, %29, %5 : tensor<256x256x!tt.ptr> + %38 = arith.addf %arg8, %37 : tensor<256x256xf32> + scf.yield %35, %38 : tensor<256x256xf32>, tensor<256x256xf32> + } + %16 = "tt.reduce"(%15#0) ({ + ^bb0(%arg6: f32, %arg7: f32): + %24 = arith.addf %arg6, %arg7 : f32 + tt.reduce.return %24 : f32 + }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32> + %17 = "tt.reduce"(%15#1) ({ + ^bb0(%arg6: f32, %arg7: f32): + %24 = arith.addf %arg6, %arg7 : f32 + tt.reduce.return %24 : f32 + }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32> + %18 = tt.splat %arg5 : i32 -> tensor<256xi32> + %19 = arith.cmpi slt, %4, %18 : tensor<256xi32> + %20 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr> + %21 = tt.addptr %20, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %21, %16, %19 : tensor<256x!tt.ptr> + %22 = tt.splat %arg3 : !tt.ptr -> tensor<256x!tt.ptr> + %23 = tt.addptr %22, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %23, %17, %19 : tensor<256x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @_layer_norm_bwd_dwdb_0123456 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_2_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_3_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = 
tensor.empty() : tensor<256x256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<256x256xf32>) -> tensor<256x256xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_9_]], [[CST_256_1_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_1_]], [[VAR_arg11_:%.+]] = [[VAR_1_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { +// CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index +// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[VAR_20_]], [[VAR_21_]] : index +// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index +// CHECK: [[VAR_24_:%.+]] = arith.addi [[VAR_22_]], [[VAR_23_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_24_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_21_]], 1] : memref to memref<256x256xf32, strided<[?, 1], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256x256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.addi [[VAR_20_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_27_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_28_0_:%.+]] = arith.maxsi [[VAR_20_]], [[VAR_27_]] : index +// CHECK: [[VAR_28_:%.+]] = arith.minsi [[VAR_26_]], [[VAR_28_0_]] : index +// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_28_]], [[VAR_20_]] : index +// CHECK-DAG: [[VAR_31_:%.+]] = arith.addi [[VAR_23_]], [[CST_256_]] : index + +// CHECK: [[VAR_33_0_:%.+]] = arith.maxsi [[VAR_23_]], [[VAR_21_]] : index +// CHECK: [[VAR_33_:%.+]] = arith.minsi [[VAR_31_]], [[VAR_33_0_]] : index + +// CHECK-DAG: [[VAR_34_:%.+]] = arith.subi [[VAR_33_]], [[VAR_23_]] : index + +// CHECK-DAG: [[VAR_35_:%.+]] = arith.minsi [[VAR_29_]], [[CST_256_]] : index +// CHECK: [[VAR_36_:%.+]] = arith.minsi [[VAR_34_]], [[CST_256_]] : index + +// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpi slt, [[VAR_35_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpi slt, [[VAR_36_]], [[CST_256_]] : index +// CHECK: [[VAR_39_:%.+]] = arith.ori [[VAR_37_]], [[VAR_38_]] : i1 +// CHECK: scf.if [[VAR_39_]] { +// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256x256xf32>) +// CHECK: } +// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[VAR_reinterpret_cast_4_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32> to memref> +// CHECK: memref.copy [[VAR_subview_5_]], [[VAR_subview_6_]] : memref> to memref> +// CHECK: [[VAR_40_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xf32> +// CHECK: [[VAR_41_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg10_]], [[VAR_40_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg10_]] : tensor<256x256xf32>) { +// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32): +// CHECK: [[VAR_64_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32 +// CHECK: linalg.yield [[VAR_64_]] : f32 +// CHECK: } -> 
tensor<256x256xf32> + +// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_24_]]{{.}}, sizes: [256, 256], strides: [[[VAR_21_]], 1] : memref to memref<256x256xf32, strided<[?, 1], offset: ?>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x256xf32> + +// CHECK: scf.if [[VAR_39_]] { +// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256x256xf32>) +// CHECK: } +// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_1_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32> to memref> +// CHECK: memref.copy [[VAR_subview_9_]], [[VAR_subview_10_]] : memref> to memref> +// CHECK: [[VAR_62_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x256xf32> +// CHECK: [[VAR_63_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg11_]], [[VAR_62_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg11_]] : tensor<256x256xf32>) { +// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32): +// CHECK: [[VAR_64_1_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32 +// CHECK: linalg.yield [[VAR_64_1_]] : f32 +// CHECK: } -> tensor<256x256xf32> +// CHECK: scf.yield [[VAR_41_]], [[VAR_63_]] : tensor<256x256xf32>, tensor<256x256xf32> +// CHECK: } +// CHECK: [[VAR_4_:%.+]] = tensor.empty() : tensor<256xf32> +// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_4_]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_3_]]#0 : tensor<256x256xf32>) outs([[VAR_5_]] : tensor<256xf32>) dimensions = [0] +// CHECK: ([[in_0:.+]]: f32, [[in_1:.+]]: f32) { +// CHECK: [[VAR_20_1_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32 +// CHECK: linalg.yield [[VAR_20_1_]] : f32 +// CHECK: } + +// CHECK: [[VAR_reduced_0_:%.+]] = linalg.reduce ins([[VAR_3_]]#1 : tensor<256x256xf32>) outs([[VAR_5_]] : tensor<256xf32>) dimensions = [0] +// CHECK: ([[in_0:.+]]: f32, [[in_1:.+]]: f32) { +// CHECK: [[VAR_20_2_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32 +// CHECK: linalg.yield [[VAR_20_2_]] : f32 +// CHECK: } +// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> + +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_8_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_12_0_:%.+]] = arith.maxsi [[VAR_8_]], [[VAR_11_]] : index +// CHECK: [[VAR_12_:%.+]] = arith.minsi [[VAR_10_]], [[VAR_12_0_]] : index +// CHECK: [[VAR_13_:%.+]] = arith.subi [[VAR_12_]], [[VAR_8_]] : index +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_reduced_]][0] {{.}}[[VAR_13_]]{{.}} [1] : tensor<256xf32> to tensor +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_13_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_]] + +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast 
[[PARAM_3_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_extracted_slice_2_:%.+]] = tensor.extract_slice [[VAR_reduced_0_]][0] {{.}}[[VAR_13_]]{{.}} [1] : tensor<256xf32> to tensor +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_13_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_2_]] in writable [[VAR_subview_3_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir new file mode 100644 index 000000000..677a804f9 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir @@ -0,0 +1,311 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @_layer_norm_fwd_fused_0123456789(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr, %arg5: !tt.ptr, %arg6: i32, %arg7: i32, %arg8: f32) { + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg6 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %cst_0 : f32 -> tensor<256xf32> + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %6 = tt.splat %arg7 : i32 -> tensor<256xi32> + %7 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr> + %8 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 { + %32 = tt.splat %arg9 : i32 -> tensor<256xi32> + %33 = arith.addi %32, %5 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %6 : tensor<256xi32> + %35 = tt.addptr %7, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34, %4 : tensor<256x!tt.ptr> + %37 = arith.addf %arg10, %36 : tensor<256xf32> + scf.yield %37 : tensor<256xf32> + } + %9 = "tt.reduce"(%8) ({ + ^bb0(%arg9: f32, %arg10: f32): + %32 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %32 : f32 + }) {axis = 0 : i32} : (tensor<256xf32>) -> f32 + %10 = arith.sitofp %arg7 : i32 to f32 + %11 = arith.divf %9, %10 : f32 + %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %13 = tt.splat %arg7 : i32 -> tensor<256xi32> + %14 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr> + %15 = tt.splat %11 : f32 -> tensor<256xf32> + %16 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 { + %32 = tt.splat %arg9 : i32 -> tensor<256xi32> + %33 = arith.addi %32, %12 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %13 : tensor<256xi32> + %35 = tt.addptr %14, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34, %4 : tensor<256x!tt.ptr> + %37 = arith.subf %36, %15 : tensor<256xf32> + %38 = arith.select %34, %37, %4 : tensor<256xi1>, tensor<256xf32> + %39 = arith.mulf %38, %38 : tensor<256xf32> + %40 = arith.addf %arg10, %39 : tensor<256xf32> + scf.yield %40 : tensor<256xf32> + } + %17 = "tt.reduce"(%16) ({ + ^bb0(%arg9: f32, %arg10: f32): + %32 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %32 : f32 + }) {axis = 0 : i32} : (tensor<256xf32>) -> f32 + %18 = arith.divf %17, %10 : f32 + %19 = arith.addf %18, %arg8 : f32 + %20 = math.sqrt %19 : f32 + %21 = arith.divf 
%cst, %20 : f32 + %22 = tt.addptr %arg4, %0 : !tt.ptr, i32 + tt.store %22, %11 : !tt.ptr + %23 = tt.addptr %arg5, %0 : !tt.ptr, i32 + tt.store %23, %21 : !tt.ptr + %24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %25 = tt.splat %arg7 : i32 -> tensor<256xi32> + %26 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr> + %27 = tt.splat %arg3 : !tt.ptr -> tensor<256x!tt.ptr> + %28 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr> + %29 = tt.splat %11 : f32 -> tensor<256xf32> + %30 = tt.splat %21 : f32 -> tensor<256xf32> + %31 = tt.splat %2 : !tt.ptr -> tensor<256x!tt.ptr> + scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 : i32 { + %32 = tt.splat %arg9 : i32 -> tensor<256xi32> + %33 = arith.addi %32, %24 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %25 : tensor<256xi32> + %35 = tt.addptr %26, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34 : tensor<256x!tt.ptr> + %37 = tt.addptr %27, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %38 = tt.load %37, %34 : tensor<256x!tt.ptr> + %39 = tt.addptr %28, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %40 = tt.load %39, %34, %4 : tensor<256x!tt.ptr> + %41 = arith.subf %40, %29 : tensor<256xf32> + %42 = arith.mulf %41, %30 : tensor<256xf32> + %43 = arith.mulf %42, %36 : tensor<256xf32> + %44 = arith.addf %43, %38 : tensor<256xf32> + %45 = tt.addptr %31, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %45, %44, %34 : tensor<256x!tt.ptr> + } + tt.return + } +} + +// CHECK-DAG: #map = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @_layer_norm_fwd_fused_0123456789 +// CHECK-SAME: (%[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_2_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_3_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_4_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_5_:%.+]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: f32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: [[c256:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[c0_i32:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[c256_i32:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[cst:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[cst_0:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[VAL_0:.*]] = tensor.empty() : tensor<1xf32> +// CHECK: %[[VAL_1:.*]] = linalg.fill ins([[cst_0]] : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<256xf32> +// CHECK: %[[VAL_3:.*]] = linalg.fill ins([[cst]] : f32) outs(%[[VAL_2]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: %[[VAL_4:.*]] = arith.muli [[PARAM_12_]], [[PARAM_6_]] : i32 +// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : i32 to index +// CHECK: %[[VAL_6:.*]] = scf.for %arg17 = [[c0_i32]] to [[PARAM_7_]] step [[c256_i32]] iter_args(%arg18 = %[[VAL_3]]) -> (tensor<256xf32>) : i32 { +// CHECK: %[[VAL_33:.*]] = arith.index_cast %arg17 : i32 to index +// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_5]], %[[VAL_33]] : index +// CHECK: %reinterpret_cast_9 = memref.reinterpret_cast [[PARAM_0_]] to offset: [%[[VAL_34]]], sizes: [256], strides: [1] : memref to 
memref<256xf32, strided<[1], offset: ?>> +// CHECK: %alloc = memref.alloc() : memref<256xf32> +// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_33]], [[c256]] : index +// CHECK: %[[VAL_36:.*]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: %[[VAL_37:.*]] = arith.maxsi %[[VAL_33]], %[[VAL_36]] : index +// CHECK: %[[VAL_38:.*]] = arith.minsi %[[VAL_35]], %[[VAL_37]] : index +// CHECK: %[[VAL_39:.*]] = arith.subi %[[VAL_38]], %[[VAL_33]] : index +// CHECK: %[[VAL_40:.*]] = arith.cmpi slt, %[[VAL_39]], [[c256]] : index +// CHECK: scf.if %[[VAL_40]] { +// CHECK: linalg.fill ins([[cst]] : f32) outs(%alloc : memref<256xf32>) +// CHECK: } +// CHECK: %subview = memref.subview %reinterpret_cast_9[0] [%[[VAL_39]]] [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: %subview_10 = memref.subview %alloc[0] [%[[VAL_39]]] [1] : memref<256xf32> to memref> +// CHECK: memref.copy %subview, %subview_10 : memref> to memref> +// CHECK: %[[VAL_41:.*]] = bufferization.to_tensor %alloc restrict writable : memref<256xf32> +// CHECK: %[[VAL_42:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg18, %[[VAL_41]] : tensor<256xf32>, tensor<256xf32>) outs(%arg18 : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_11: f32, %out: f32): +// CHECK: %[[VAL_43:.*]] = arith.addf %in, %in_11 : f32 +// CHECK: linalg.yield %[[VAL_43]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: scf.yield %[[VAL_42]] : tensor<256xf32> +// CHECK: } +// CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_8:.*]] = linalg.fill ins([[cst]] : f32) outs(%[[VAL_7]] : tensor) -> tensor +// CHECK: %reduced = linalg.reduce ins(%[[VAL_6]] : tensor<256xf32>) outs(%[[VAL_8]] : tensor) dimensions = [0] +// CHECK: (%in: f32, %init: f32) { +// CHECK: %[[VAL_33:.*]] = arith.addf %in, %init : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } +// CHECK: %extracted = tensor.extract %reduced[] : tensor +// CHECK: %[[VAL_9:.*]] = arith.sitofp [[PARAM_7_]] : i32 to f32 +// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%extracted : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_12:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_10]], %[[VAL_11]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_10]] : tensor<1xf32>) { +// CHECK: ^bb0(%in: f32, %in_9: f32, %out: f32): +// CHECK: %[[VAL_33:.*]] = arith.divf %in, %in_9 : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<1xf32> +// CHECK: %extracted_1 = tensor.extract %[[VAL_12]][[[c0]]] : tensor<1xf32> +// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<256xi32> +// CHECK: %[[VAL_14:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_13]] : tensor<256xi32>) { +// CHECK: ^bb0(%out: i32): +// CHECK: %[[VAL_33:.*]] = linalg.index 0 : index +// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i32 +// CHECK: linalg.yield %[[VAL_34]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[VAL_15:.*]] = linalg.fill ins([[PARAM_7_]] : i32) outs(%[[VAL_13]] : tensor<256xi32>) -> tensor<256xi32> +// CHECK: %[[VAL_16:.*]] = linalg.fill ins(%extracted_1 : f32) outs(%[[VAL_2]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: %[[VAL_17:.*]] = scf.for %arg17 = [[c0_i32]] to [[PARAM_7_]] step [[c256_i32]] iter_args(%arg18 = %[[VAL_3]]) -> (tensor<256xf32>) : i32 { +// CHECK: %[[VAL_33:.*]] = 
linalg.fill ins(%arg17 : i32) outs(%[[VAL_13]] : tensor<256xi32>) -> tensor<256xi32> +// CHECK: %[[VAL_34:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_33]], %[[VAL_14]] : tensor<256xi32>, tensor<256xi32>) outs(%[[VAL_33]] : tensor<256xi32>) { +// CHECK: ^bb0(%in: i32, %in_11: i32, %out: i32): +// CHECK: %[[VAL_50:.*]] = arith.addi %in, %in_11 : i32 +// CHECK: linalg.yield %[[VAL_50:.*]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[VAL_35:.*]] = tensor.empty() : tensor<256xi1> +// CHECK: %[[VAL_36:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_34]], %[[VAL_15]] : tensor<256xi32>, tensor<256xi32>) outs(%[[VAL_35]] : tensor<256xi1>) { +// CHECK: ^bb0(%in: i32, %in_11: i32, %out: i1): +// CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %in, %in_11 : i32 +// CHECK: linalg.yield %[[VAL_50:.*]] : i1 +// CHECK: } -> tensor<256xi1> +// CHECK: %[[VAL_37:.*]] = arith.index_cast %arg17 : i32 to index +// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_5]], %[[VAL_37]] : index +// CHECK: %reinterpret_cast_9 = memref.reinterpret_cast [[PARAM_0_]] to offset: [%[[VAL_38]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: %alloc = memref.alloc() : memref<256xf32> +// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_37]], [[c256]] : index +// CHECK: %[[VAL_40:.*]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: %[[VAL_41:.*]] = arith.maxsi %[[VAL_37]], %[[VAL_40]] : index +// CHECK: %[[VAL_42:.*]] = arith.minsi %[[VAL_39]], %[[VAL_41]] : index +// CHECK: %[[VAL_43:.*]] = arith.subi %[[VAL_42]], %[[VAL_37]] : index +// CHECK: %[[VAL_44:.*]] = arith.cmpi slt, %[[VAL_43]], [[c256]] : index +// CHECK: scf.if %[[VAL_44]] { +// CHECK: linalg.fill ins([[cst]] : f32) outs(%alloc : memref<256xf32>) +// CHECK: } +// CHECK: %subview = memref.subview %reinterpret_cast_9[0] [%[[VAL_43]]] [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: %subview_10 = memref.subview %alloc[0] [%[[VAL_43]]] [1] : memref<256xf32> to memref> +// CHECK: memref.copy %subview, %subview_10 : memref> to memref> +// CHECK: %[[VAL_45:.*]] = bufferization.to_tensor %alloc restrict writable : memref<256xf32> +// CHECK: %[[VAL_46:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_45]], %[[VAL_16]] : tensor<256xf32>, tensor<256xf32>) outs(%[[VAL_45]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_11: f32, %out: f32): +// CHECK: %[[VAL_50:.*]] = arith.subf %in, %in_11 : f32 +// CHECK: linalg.yield %[[VAL_50:.*]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %[[VAL_47:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_36]], %[[VAL_46]], %[[VAL_3]] : tensor<256xi1>, tensor<256xf32>, tensor<256xf32>) outs(%[[VAL_46]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: i1, %in_11: f32, %in_12: f32, %out: f32): +// CHECK: %[[VAL_50:.*]] = arith.select %in, %in_11, %in_12 : f32 +// CHECK: linalg.yield %[[VAL_50:.*]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %[[VAL_48:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_47]], %[[VAL_47]] : tensor<256xf32>, tensor<256xf32>) outs(%[[VAL_47]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_11: f32, %out: f32): +// CHECK: %[[VAL_50:.*]] = arith.mulf %in, %in_11 : f32 +// CHECK: linalg.yield %[[VAL_50:.*]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %[[VAL_49:.*]] = linalg.generic 
{indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg18, %[[VAL_48]] : tensor<256xf32>, tensor<256xf32>) outs(%arg18 : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_11: f32, %out: f32): +// CHECK: %[[VAL_50:.*]] = arith.addf %in, %in_11 : f32 +// CHECK: linalg.yield %[[VAL_50:.*]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: scf.yield %[[VAL_49]] : tensor<256xf32> +// CHECK: } +// CHECK: %[[VAL_18:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_19:.*]] = linalg.fill ins([[cst]] : f32) outs(%[[VAL_18]] : tensor) -> tensor +// CHECK: %reduced_2 = linalg.reduce ins(%[[VAL_17]] : tensor<256xf32>) outs(%[[VAL_19]] : tensor) dimensions = [0] +// CHECK: (%in: f32, %init: f32) { +// CHECK: %[[VAL_33:.*]] = arith.addf %in, %init : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } +// CHECK: %extracted_3 = tensor.extract %reduced_2[] : tensor +// CHECK: %[[VAL_20:.*]] = linalg.fill ins(%extracted_3 : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_20]], %[[VAL_11]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_20]] : tensor<1xf32>) { +// CHECK: ^bb0(%in: f32, %in_9: f32, %out: f32): +// CHECK: %[[VAL_33:.*]] = arith.divf %in, %in_9 : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<1xf32> +// CHECK: %extracted_4 = tensor.extract %[[VAL_21]][[[c0]]] : tensor<1xf32> +// CHECK: %[[VAL_22:.*]] = linalg.fill ins(%extracted_4 : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_23:.*]] = linalg.fill ins([[PARAM_8_]] : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_24:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_22]], %[[VAL_23]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_22]] : tensor<1xf32>) { +// CHECK: ^bb0(%in: f32, %in_9: f32, %out: f32): +// CHECK: %[[VAL_33:.*]] = arith.addf %in, %in_9 : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<1xf32> +// CHECK: %extracted_5 = tensor.extract %[[VAL_24]][[[c0]]] : tensor<1xf32> +// CHECK: %[[VAL_25:.*]] = linalg.fill ins(%extracted_5 : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_26:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_25]] : tensor<1xf32>) outs(%[[VAL_25]] : tensor<1xf32>) { +// CHECK: ^bb0(%in: f32, %out: f32): +// CHECK: %[[VAL_33:.*]] = math.sqrt %in : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<1xf32> +// CHECK: %extracted_6 = tensor.extract %[[VAL_26]][[[c0]]] : tensor<1xf32> +// CHECK: %[[VAL_27:.*]] = linalg.fill ins(%extracted_6 : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_28:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_1]], %[[VAL_27]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_1]] : tensor<1xf32>) { +// CHECK: ^bb0(%in: f32, %in_9: f32, %out: f32): +// CHECK: %[[VAL_33:.*]] = arith.divf %in, %in_9 : f32 +// CHECK: linalg.yield %[[VAL_33]] : f32 +// CHECK: } -> tensor<1xf32> +// CHECK: %extracted_7 = tensor.extract %[[VAL_28]][[[c0]]] : tensor<1xf32> +// CHECK: %[[VAL_29:.*]] = arith.index_cast [[PARAM_12_]] : i32 to index +// CHECK: %[[VAL_30:.*]] = linalg.fill ins(%extracted_1 : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %reinterpret_cast = memref.reinterpret_cast [[PARAM_4_]] to offset: [%[[VAL_29]]], sizes: 
[1], strides: [1] : memref to memref<1xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_30]] in writable %reinterpret_cast : (tensor<1xf32>, memref<1xf32, strided<[1], offset: ?>>) -> () +// CHECK: %[[VAL_31:.*]] = linalg.fill ins(%extracted_7 : f32) outs(%[[VAL_0]] : tensor<1xf32>) -> tensor<1xf32> +// CHECK: %reinterpret_cast_8 = memref.reinterpret_cast [[PARAM_5_]] to offset: [%[[VAL_29]]], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_31]] in writable %reinterpret_cast_8 : (tensor<1xf32>, memref<1xf32, strided<[1], offset: ?>>) -> () +// CHECK: %[[VAL_32:.*]] = linalg.fill ins(%extracted_7 : f32) outs(%[[VAL_2]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: scf.for %arg17 = [[c0_i32]] to [[PARAM_7_]] step [[c256_i32]] : i32 { +// CHECK: %[[VAL_33:.*]] = arith.index_cast %arg17 : i32 to index +// CHECK: %reinterpret_cast_9 = memref.reinterpret_cast [[PARAM_2_]] to offset: [%[[VAL_33]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: %alloc = memref.alloc() : memref<256xf32> +// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], [[c256]] : index +// CHECK: %[[VAL_35:.*]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: %[[VAL_36:.*]] = arith.maxsi %[[VAL_33]], %[[VAL_35]] : index +// CHECK: %[[VAL_37:.*]] = arith.minsi %[[VAL_34]], %[[VAL_36]] : index +// CHECK: %[[VAL_38:.*]] = arith.subi %[[VAL_37]], %[[VAL_33]] : index +// CHECK: %subview = memref.subview %reinterpret_cast_9[0] [%[[VAL_38]]] [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: %subview_10 = memref.subview %alloc[0] [%[[VAL_38]]] [1] : memref<256xf32> to memref> +// CHECK: memref.copy %subview, %subview_10 : memref> to memref> +// CHECK: %[[VAL_39:.*]] = bufferization.to_tensor %alloc restrict writable : memref<256xf32> +// CHECK: %reinterpret_cast_11 = memref.reinterpret_cast [[PARAM_3_]] to offset: [%[[VAL_33]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: %alloc_12 = memref.alloc() : memref<256xf32> +// CHECK: %subview_13 = memref.subview %reinterpret_cast_11[0] [%[[VAL_38]]] [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: %subview_14 = memref.subview %alloc_12[0] [%[[VAL_38]]] [1] : memref<256xf32> to memref> +// CHECK: memref.copy %subview_13, %subview_14 : memref> to memref> +// CHECK: %[[VAL_40:.*]] = bufferization.to_tensor %alloc_12 restrict writable : memref<256xf32> +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_5]], %[[VAL_33]] : index +// CHECK: %reinterpret_cast_15 = memref.reinterpret_cast [[PARAM_0_]] to offset: [%[[VAL_41]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: %alloc_16 = memref.alloc() : memref<256xf32> +// CHECK: %[[VAL_42:.*]] = arith.cmpi slt, %[[VAL_38]], [[c256]] : index +// CHECK: scf.if %[[VAL_42]] { +// CHECK: linalg.fill ins([[cst]] : f32) outs(%alloc_16 : memref<256xf32>) +// CHECK: } +// CHECK: %subview_17 = memref.subview %reinterpret_cast_15[0] [%[[VAL_38]]] [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: %subview_18 = memref.subview %alloc_16[0] [%[[VAL_38]]] [1] : memref<256xf32> to memref> +// CHECK: memref.copy %subview_17, %subview_18 : memref> to memref> +// CHECK: %[[VAL_43:.*]] = bufferization.to_tensor %alloc_16 restrict writable : memref<256xf32> +// CHECK: %[[VAL_44:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = 
["parallel"]} ins(%[[VAL_43]], %[[VAL_16]] : tensor<256xf32>, tensor<256xf32>) outs(%[[VAL_43]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_21: f32, %out: f32): +// CHECK: %[[VAL_48:.*]] = arith.subf %in, %in_21 : f32 +// CHECK: linalg.yield %[[VAL_48]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %[[VAR_45:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_44]], %[[VAL_32]] : tensor<256xf32>, tensor<256xf32>) outs(%[[VAL_44]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_21: f32, %out: f32): +// CHECK: %[[VAL_48:.*]] = arith.mulf %in, %in_21 : f32 +// CHECK: linalg.yield %[[VAL_48]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %[[VAR_46:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAR_45]], %[[VAL_39]] : tensor<256xf32>, tensor<256xf32>) outs(%[[VAR_45]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_21: f32, %out: f32): +// CHECK: %[[VAL_48:.*]] = arith.mulf %in, %in_21 : f32 +// CHECK: linalg.yield %[[VAL_48]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %[[VAL_47:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAR_46]], %[[VAL_40]] : tensor<256xf32>, tensor<256xf32>) outs(%[[VAR_46]] : tensor<256xf32>) { +// CHECK: ^bb0(%in: f32, %in_21: f32, %out: f32): +// CHECK: %[[VAL_48:.*]] = arith.addf %in, %in_21 : f32 +// CHECK: linalg.yield %[[VAL_48]] : f32 +// CHECK: } -> tensor<256xf32> +// CHECK: %reinterpret_cast_19 = memref.reinterpret_cast [[PARAM_1_]] to offset: [%[[VAL_41]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: %extracted_slice = tensor.extract_slice %[[VAL_47]][0] [%[[VAL_38]]] [1] : tensor<256xf32> to tensor +// CHECK: %subview_20 = memref.subview %reinterpret_cast_19[0] [%[[VAL_38]]] [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination %extracted_slice in writable %subview_20 : (tensor, memref>) -> () +// CHECK: } +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/libdevice_flip.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/libdevice_flip.mlir new file mode 100644 index 000000000..cde333fbe --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/libdevice_flip.mlir @@ -0,0 +1,36 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { +tt.func public @fn_npu_flip(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<2> : tensor<4x8x8xi32> + %cst_0 = arith.constant dense<8> : tensor<1x8x1xi32> + %cst_1 = arith.constant dense<8> : tensor<4x1x1xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<4x1xi32> -> tensor<4x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<4x1x1xi32> + %5 = arith.muli %4, %cst_1 : tensor<4x1x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst_0 : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<4x1x1xi32> -> tensor<4x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<4x8x1xi32> + %11 = arith.addi %9, %10 : tensor<4x8x1xi32> + %12 = tt.expand_dims %6 {axis 
= 1 : i32} : tensor<1x8xi32> -> tensor<1x1x8xi32> + %13 = tt.broadcast %11 : tensor<4x8x1xi32> -> tensor<4x8x8xi32> + %14 = tt.broadcast %12 : tensor<1x1x8xi32> -> tensor<4x8x8xi32> + %15 = arith.addi %13, %14 : tensor<4x8x8xi32> + %16 = tt.splat %arg1 : !tt.ptr -> tensor<4x8x8x!tt.ptr> + %17 = tt.addptr %16, %15 : tensor<4x8x8x!tt.ptr>, tensor<4x8x8xi32> + %18 = tt.load %17 : tensor<4x8x8x!tt.ptr> + %19 = tt.extern_elementwise %18, %cst {libname = "", libpath = "", pure = true, symbol = "__hmf_flipf"} : (tensor<4x8x8xf32>, tensor<4x8x8xi32>) -> tensor<4x8x8xf32> + %20 = tt.splat %arg0 : !tt.ptr -> tensor<4x8x8x!tt.ptr> + %21 = tt.addptr %20, %15 : tensor<4x8x8x!tt.ptr>, tensor<4x8x8xi32> + tt.store %21, %19 : tensor<4x8x8x!tt.ptr> + tt.return +} +} + +//CHECK: func.func private @__hmf_flipf(f32, i32) -> f32 attributes {llvm.readnone} +//CHECK-NOT: tt.extern_elementwise +//CHECK: %[[RESULT:.*]] = linalg.map { func.call {callee = @__hmf_flipf} } ins(%[[TENSOR:.*]], %[[DIM:.*]] : tensor<4x8x8xf32>, tensor<4x8x8xi32>) outs(%[[TENSOR]] : tensor<4x8x8xf32>) diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/libdevice_round.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/libdevice_round.mlir new file mode 100644 index 000000000..91b19fb75 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/libdevice_round.mlir @@ -0,0 +1,32 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { +tt.func public @test_round(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 :i32 + %c32_i32 = arith.constant 32 :i32 + %c0_i32 = arith.constant 0 :i32 + %c64_i32 = arith.constant 64 :i32 + %c2048_i32 = arith.constant 2048 :i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c2048_i32 :i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 : i32 { + %5 = arith.muli %arg2, %c64_i32 : i32 + %6 = arith.addi %1, %5 : i32 + %7 = tt.splat %6 : i32 -> tensor<64xi32> + %8 = arith.addi %7, %2 : tensor<64xi32> + %9 = tt.addptr %3, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + %10 = tt.load %9 : tensor<64x!tt.ptr> + %11 = tt.extern_elementwise %10 {libname ="", libpath= "", pure = true, symbol= "__hmf_roundf"} : (tensor<64xf32>) -> tensor<64xf32> + %12 = tt.addptr %4, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %12, %11 : tensor<64x!tt.ptr> + } + tt.return +} +} + + +//CHECK: func.func private @__hmf_roundf(f32) -> f32 attributes {llvm.readnone} +//CHECK: %[[RESULT:.*]] = linalg.map { func.call {callee = @__hmf_roundf} } ins(%[[SOURCE:.*]] : tensor<64xf32>) outs(%[[SOURCE]] : tensor<64xf32>) +//CHECK: bufferization.materialize_in_destination %[[RESULT]] in writable %[[DST:.*]] : (tensor<64xf32>, memref<64xf32, strided<[1], offset: ?>>) -> () \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/linear_compress_bwd_kernel.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/linear_compress_bwd_kernel.mlir new file mode 100644 index 000000000..7e8a0ab18 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/linear_compress_bwd_kernel.mlir @@ -0,0 +1,162 @@ +// RUN: triton-adapter-opt --discrete-mask-access-conversion "--triton-to-linalg=global-kernel=false named-ops=True" %s | FileCheck %s + +module { + tt.func public @linear_compress_bwd_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : 
i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // CHECK-LABEL: func.func @linear_compress_bwd_kernel + %cst = arith.constant dense<0.000000e+00> : tensor<16x32x4xf16> + %c256_i32 = arith.constant 256 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<16x128xf32> + %cst_2 = arith.constant dense<128> : tensor<32xi32> + %cst_3 = arith.constant dense<32> : tensor<32xi32> + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c32_i64 = arith.constant 32 : i64 + %cst_4 = arith.constant dense<128> : tensor<4xi32> + %c4_i32 = arith.constant 4 : i32 + %cst_5 = arith.constant dense<16> : tensor<16xi32> + %c16_i32 = arith.constant 16 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c8_i32 : i32 + %2 = arith.remsi %0, %c8_i32 : i32 + %3 = tt.get_program_id y : i32 + %4 = tt.get_program_id z : i32 + %5 = arith.divsi %4, %c32_i32 : i32 + %6 = arith.remsi %4, %c32_i32 : i32 + %7 = tt.addptr %arg5, %1 : !tt.ptr, i32 + %8 = tt.load %7 : !tt.ptr + %9 = tt.addptr %7, %c1_i32 : !tt.ptr, i32 + %10 = tt.load %9 : !tt.ptr + %11 = arith.subi %10, %8 : i32 + %12 = tt.addptr %arg6, %1 : !tt.ptr, i32 + %13 = tt.load %12 : !tt.ptr + %14 = tt.addptr %12, %c1_i32 : !tt.ptr, i32 + %15 = tt.load %14 : !tt.ptr + %16 = arith.subi %15, %13 : i32 + %17 = arith.muli %3, %c16_i32 : i32 + %18 = arith.cmpi sge, %17, %16 : i32 + cf.cond_br %18, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + tt.return + ^bb2: // pred: ^bb0 + %19 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %20 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %21 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %22 = arith.muli %2, %arg8 : i32 + %23 = tt.addptr %arg3, %22 : !tt.ptr, i32 + %24 = arith.muli %8, %arg7 : i32 + %25 = tt.addptr %23, %24 : !tt.ptr, i32 + %26 = arith.muli %3, %c256_i32 : i32 + %27 = arith.muli %21, %cst_5 : tensor<16xi32> + %28 = tt.splat %26 : i32 -> tensor<16xi32> + %29 = arith.addi %28, %27 : tensor<16xi32> + %30 = tt.expand_dims %29 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %31 = tt.expand_dims %19 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %32 = tt.broadcast %30 : tensor<16x1xi32> -> tensor<16x32xi32> + %33 = tt.broadcast %31 : tensor<1x32xi32> -> tensor<16x32xi32> + %34 = arith.addi %32, %33 : tensor<16x32xi32> + %35 = tt.expand_dims %34 {axis = 2 : i32} : tensor<16x32xi32> -> tensor<16x32x1xi32> + %36 = tt.splat %arg7 : i32 -> tensor<16x32x1xi32> + %37 = arith.muli %35, %36 : tensor<16x32x1xi32> + %38 = tt.splat %25 : !tt.ptr -> tensor<16x32x1x!tt.ptr> + %39 
= tt.addptr %38, %37 : tensor<16x32x1x!tt.ptr>, tensor<16x32x1xi32> + %40 = arith.muli %6, %c4_i32 : i32 + %41 = tt.splat %40 : i32 -> tensor<4xi32> + %42 = arith.addi %41, %20 : tensor<4xi32> + %43 = tt.expand_dims %42 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %45 = tt.broadcast %39 : tensor<16x32x1x!tt.ptr> -> tensor<16x32x4x!tt.ptr> + %46 = tt.broadcast %44 : tensor<1x1x4xi32> -> tensor<16x32x4xi32> + %47 = tt.addptr %45, %46 : tensor<16x32x4x!tt.ptr>, tensor<16x32x4xi32> + %48 = tt.splat %11 : i32 -> tensor<16x32xi32> + %49 = arith.cmpi slt, %34, %48 : tensor<16x32xi32> + %50 = tt.expand_dims %49 {axis = 2 : i32} : tensor<16x32xi1> -> tensor<16x32x1xi1> + %51 = arith.cmpi slt, %42, %cst_4 : tensor<4xi32> + %52 = tt.expand_dims %51 {axis = 0 : i32} : tensor<4xi1> -> tensor<1x4xi1> + %53 = tt.expand_dims %52 {axis = 1 : i32} : tensor<1x4xi1> -> tensor<1x1x4xi1> + %54 = tt.broadcast %50 : tensor<16x32x1xi1> -> tensor<16x32x4xi1> + %55 = tt.broadcast %53 : tensor<1x1x4xi1> -> tensor<16x32x4xi1> + %56 = arith.andi %54, %55 : tensor<16x32x4xi1> + %57 = arith.muli %2, %arg13 : i32 + %58 = tt.addptr %arg0, %57 : !tt.ptr, i32 + %59 = arith.muli %8, %arg12 : i32 + %60 = tt.addptr %58, %59 : !tt.ptr, i32 + %61 = tt.splat %arg12 : i32 -> tensor<16x32x1xi32> + %62 = arith.muli %35, %61 : tensor<16x32x1xi32> + %63 = tt.splat %60 : !tt.ptr -> tensor<16x32x1x!tt.ptr> + %64 = tt.addptr %63, %62 : tensor<16x32x1x!tt.ptr>, tensor<16x32x1xi32> + %65 = tt.broadcast %64 : tensor<16x32x1x!tt.ptr> -> tensor<16x32x4x!tt.ptr> + %66 = tt.addptr %65, %46 : tensor<16x32x4x!tt.ptr>, tensor<16x32x4xi32> + %67 = arith.muli %2, %arg9 : i32 + %68 = tt.addptr %arg4, %67 : !tt.ptr, i32 + %69 = arith.muli %5, %c32_i32 : i32 + %70 = arith.extsi %arg10 : i32 to i64 + %71 = arith.extsi %arg11 : i32 to i64 + %72 = tt.make_tensor_ptr %68, [%c32_i64, %c128_i64, %c128_i64], [%70, %71, %c1_i64], [%c0_i32, %40, %69] {order = array} : > + %73 = arith.muli %2, %arg14 : i32 + %74 = tt.addptr %arg2, %73 : !tt.ptr, i32 + %75 = tt.expand_dims %19 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %76 = tt.expand_dims %75 {axis = 2 : i32} : tensor<32x1xi32> -> tensor<32x1x1xi32> + %77 = tt.splat %arg15 : i32 -> tensor<32x1x1xi32> + %78 = arith.muli %76, %77 : tensor<32x1x1xi32> + %79 = tt.splat %74 : !tt.ptr -> tensor<32x1x1x!tt.ptr> + %80 = tt.addptr %79, %78 : tensor<32x1x1x!tt.ptr>, tensor<32x1x1xi32> + %81 = tt.expand_dims %43 {axis = 2 : i32} : tensor<1x4xi32> -> tensor<1x4x1xi32> + %82 = tt.splat %arg16 : i32 -> tensor<1x4x1xi32> + %83 = arith.muli %81, %82 : tensor<1x4x1xi32> + %84 = tt.broadcast %80 : tensor<32x1x1x!tt.ptr> -> tensor<32x4x1x!tt.ptr> + %85 = tt.broadcast %83 : tensor<1x4x1xi32> -> tensor<32x4x1xi32> + %86 = tt.addptr %84, %85 : tensor<32x4x1x!tt.ptr>, tensor<32x4x1xi32> + %87 = tt.splat %69 : i32 -> tensor<32xi32> + %88 = arith.addi %87, %19 : tensor<32xi32> + %89 = tt.expand_dims %88 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %90 = tt.expand_dims %89 {axis = 1 : i32} : tensor<1x32xi32> -> tensor<1x1x32xi32> + %91 = tt.broadcast %86 : tensor<32x4x1x!tt.ptr> -> tensor<32x4x32x!tt.ptr> + %92 = tt.broadcast %90 : tensor<1x1x32xi32> -> tensor<32x4x32xi32> + %93 = tt.addptr %91, %92 : tensor<32x4x32x!tt.ptr>, tensor<32x4x32xi32> + %94 = arith.cmpi slt, %19, %cst_3 : tensor<32xi32> + %95 = tt.expand_dims %94 {axis = 1 : i32} : tensor<32xi1> -> tensor<32x1xi1> + %96 = tt.expand_dims %95 {axis = 2 : i32} : 
tensor<32x1xi1> -> tensor<32x1x1xi1> + %97 = tt.expand_dims %52 {axis = 2 : i32} : tensor<1x4xi1> -> tensor<1x4x1xi1> + %98 = tt.broadcast %96 : tensor<32x1x1xi1> -> tensor<32x4x1xi1> + %99 = tt.broadcast %97 : tensor<1x4x1xi1> -> tensor<32x4x1xi1> + %100 = arith.andi %98, %99 : tensor<32x4x1xi1> + %101 = arith.cmpi slt, %88, %cst_2 : tensor<32xi32> + %102 = tt.expand_dims %101 {axis = 0 : i32} : tensor<32xi1> -> tensor<1x32xi1> + %103 = tt.expand_dims %102 {axis = 1 : i32} : tensor<1x32xi1> -> tensor<1x1x32xi1> + %104 = tt.broadcast %100 : tensor<32x4x1xi1> -> tensor<32x4x32xi1> + %105 = tt.broadcast %103 : tensor<1x1x32xi1> -> tensor<32x4x32xi1> + %106 = arith.andi %104, %105 : tensor<32x4x32xi1> + %107 = arith.muli %13, %arg17 : i32 + %108 = tt.addptr %arg1, %107 : !tt.ptr, i32 + %109 = arith.muli %2, %arg18 : i32 + %110 = tt.addptr %108, %109 : !tt.ptr, i32 + %111 = arith.extsi %16 : i32 to i64 + %112 = arith.extsi %arg17 : i32 to i64 + %113 = tt.make_tensor_ptr %110, [%111, %c128_i64], [%112, %c1_i64], [%17, %69] {order = array} : > + %114 = tt.load %113 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %115 = tt.load %72 {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + %116 = tt.reshape %115 : tensor<32x4x32xf16> -> tensor<128x32xf16> + tt.annotation %116 {maybeUnCollapsibleReshape} : tensor<128x32xf16> + %117 = tt.trans %116 {order = array} : tensor<128x32xf16> -> tensor<32x128xf16> + %118 = tt.dot %114, %117, %cst_1 : tensor<16x32xf16> * tensor<32x128xf16> -> tensor<16x128xf32> + %119 = tt.reshape %118 : tensor<16x128xf32> -> tensor<16x32x4xf32> + %120 = tt.atomic_rmw fadd, acq_rel, gpu, %66, %119, %56 : (tensor<16x32x4x!tt.ptr>, tensor<16x32x4xf32>, tensor<16x32x4xi1>) -> tensor<16x32x4xf32> + // CHECK: [[VAR0:%[a-zA-Z0-9_]+]] = tensor.reshape {{%[a-zA-Z0-9_]+}}({{%[a-zA-Z0-9_]+}}) : (tensor<16x128xf32>, tensor<3xi64>) -> tensor<16x32x4xf32> + // CHECK-NEXT: [[VAR1:%[a-zA-Z0-9_]+]] = arith.select {{%[a-zA-Z0-9_]+}}, [[VAR0]], {{%[a-zA-Z0-9_]+}} : tensor<16x32x4xi1>, tensor<16x32x4xf32> + // CHECK-NEXT: [[VAR2:%[a-zA-Z0-9_]+]] = bufferization.to_memref [[VAR1]] : memref<16x32x4xf32> + // CHECK-NEXT: linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins({{%[a-zA-Z0-9_]+}}, [[VAR2]] : + %121 = tt.load %47, %56, %cst : tensor<16x32x4x!tt.ptr> + %122 = tt.reshape %121 : tensor<16x32x4xf16> -> tensor<16x128xf16> + tt.annotation %122 {maybeUnCollapsibleReshape} : tensor<16x128xf16> + %123 = tt.trans %122 {order = array} : tensor<16x128xf16> -> tensor<128x16xf16> + %124 = tt.dot %123, %114, %cst_0 : tensor<128x16xf16> * tensor<16x32xf16> -> tensor<128x32xf32> + %125 = tt.reshape %124 : tensor<128x32xf32> -> tensor<32x4x32xf32> + %126 = tt.atomic_rmw fadd, acq_rel, gpu, %93, %125, %106 : (tensor<32x4x32x!tt.ptr>, tensor<32x4x32xf32>, tensor<32x4x32xi1>) -> tensor<32x4x32xf32> + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/load_permute_gen_by_make_tensor_ptr.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/load_permute_gen_by_make_tensor_ptr.mlir new file mode 100644 index 000000000..09de85228 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/load_permute_gen_by_make_tensor_ptr.mlir @@ -0,0 +1,161 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-unstructure --bubble-up-operation --discrete-mask-access-conversion --triton-to-hivm "--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False" %s | FileCheck %s + +module 
{ + // CHECK-LABEL: func.func @_attn_fwd + tt.func public @_attn_fwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: !tt.ptr {tt.divisibility = 16 : i32} , %arg4: !tt.ptr {tt.divisibility = 16 : i32} , %arg5: f32 ) attributes {noinline = false} { + %cst = arith.constant dense<1.000000e+00> : tensor<64xf32> + %cst_0 = arith.constant dense<0xFF800000> : tensor<64xf32> + %cst_1 = arith.constant dense<-1.000000e+06> : tensor<64x64xf32> + %cst_2 = arith.constant dense<0.72134751> : tensor<64xf32> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x64xf32> + %c1_i32 = arith.constant 1 : i32 + %c128_i32 = arith.constant 128 : i32 + %c2048_i32 = arith.constant 2048 : i32 + %cst_4 = arith.constant 1.44269502 : f32 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c64_i32 = arith.constant 64 : i32 + %c131072_i64 = arith.constant 131072 : i64 + %c4194304_i64 = arith.constant 4194304 : i64 + %c32_i32 = arith.constant 32 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %3 = tt.splat %1 : i32 -> tensor<64xi32> + %4 = arith.addi %3, %2 : tensor<64xi32> + %5 = arith.mulf %arg5, %cst_4 : f32 + %6 = tt.splat %5 : f32 -> tensor<64xf32> + %7 = tt.splat %5 : f32 -> tensor<64x64xf32> + %8 = arith.muli %0, %c64_i32 {tt.divisibility = dense<64> : tensor<1xi32>} : i32 + %9 = arith.addi %0, %c1_i32 : i32 + %10 = arith.muli %9, %c64_i32 : i32 + %11 = tt.expand_dims %4 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %12 = tt.expand_dims %2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %13 = tt.broadcast %11 : tensor<64x1xi32> -> tensor<64x64xi32> + %14 = tt.splat %5 : f32 -> tensor<64x64xf32> + %15 = tt.splat %5 : f32 -> tensor<64xf32> + scf.for %arg6 = %c0_i32 to %c128_i32 step %c1_i32 : i32 { + %16 = arith.divsi %arg6, %c32_i32 : i32 + %17 = arith.remsi %arg6, %c32_i32 : i32 + %18 = arith.extsi %16 : i32 to i64 + %19 = arith.muli %18, %c4194304_i64 : i64 + %20 = arith.extsi %17 : i32 to i64 + %21 = arith.muli %20, %c131072_i64 : i64 + %22 = arith.addi %19, %21 : i64 + %23 = tt.addptr %arg0, %22 : !tt.ptr, i64 + %24 = tt.make_tensor_ptr %23, [%c2048_i64, %c64_i64], [%c64_i64, %c1_i64], [%1, %c0_i32] {order = array} : > + // CHECK-NOT: annotation.mark %[[COPYDST0:.*]] {MayImplicitTransposeWithLastAxis} : memref<64x64xf16> + // CHECK-NOT: annotation.mark %[[LOADED0:.*]] {MayImplicitTransposeWithLastAxis} : tensor<64x64xf16> + %25 = tt.addptr %arg2, %22 : !tt.ptr, i64 + %26 = tt.make_tensor_ptr %25, [%c2048_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + // CHECK-NOT: annotation.mark %[[COPYDST1:.*]] {MayImplicitTransposeWithLastAxis} : memref<64x64xf16> + // CHECK-NOT: annotation.mark %[[LOADED1:.*]] {MayImplicitTransposeWithLastAxis} : tensor<64x64xf16> + %27 = tt.addptr %arg1, %22 : !tt.ptr, i64 + %28 = tt.make_tensor_ptr %27, [%c2048_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + // CHECK-NOT: annotation.mark %[[COPYDST2:.*]] {MayImplicitTransposeWithLastAxis} : memref<64x64xf16> + // CHECK-NOT: annotation.mark %[[LOADED2:.*]] {MayImplicitTransposeWithLastAxis} : tensor<64x64xf16> + %29 = tt.addptr %arg4, %22 : !tt.ptr, i64 + %30 = tt.make_tensor_ptr %29, [%c2048_i64, %c64_i64], [%c64_i64, %c1_i64], [%1, %c0_i32] {order = array} : > + // 
CHECK-NOT: annotation.mark %[[COPYDST3:.*]] {MayImplicitTransposeWithLastAxis} : memref<64x64xf16> + // CHECK-NOT: annotation.mark %[[LOADED3:.*]] {MayImplicitTransposeWithLastAxis} : tensor<64x64xf16> + %31 = tt.load %24 : !tt.ptr> + %32:5 = scf.for %arg7 = %c0_i32 to %1 step %c64_i32 iter_args(%arg8 = %cst, %arg9 = %cst_3, %arg10 = %cst_0, %arg11 = %26, %arg12 = %28) -> (tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %46 = tt.load %arg12 : !tt.ptr> + %47 = tt.trans %46 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> + %48 = tt.dot %31, %47, %cst_3 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> + %49 = "tt.reduce"(%48) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %72 = arith.maxnumf %arg13, %arg14 : f32 + tt.reduce.return %72 : f32 + }) : (tensor<64x64xf32>) -> tensor<64xf32> + %50 = arith.mulf %49, %6 : tensor<64xf32> + %51 = arith.maxnumf %arg10, %50 : tensor<64xf32> + %52 = arith.mulf %48, %7 : tensor<64x64xf32> + %53 = tt.expand_dims %51 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> + %54 = tt.broadcast %53 : tensor<64x1xf32> -> tensor<64x64xf32> + %55 = arith.subf %52, %54 : tensor<64x64xf32> + %56 = math.exp2 %55 : tensor<64x64xf32> + %57 = "tt.reduce"(%56) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %72 = arith.addf %arg13, %arg14 : f32 + tt.reduce.return %72 : f32 + }) : (tensor<64x64xf32>) -> tensor<64xf32> + %58 = arith.subf %arg10, %51 : tensor<64xf32> + %59 = math.exp2 %58 : tensor<64xf32> + %60 = arith.mulf %arg8, %59 : tensor<64xf32> + %61 = arith.addf %60, %57 : tensor<64xf32> + %62 = tt.expand_dims %59 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> + %63 = tt.broadcast %62 : tensor<64x1xf32> -> tensor<64x64xf32> + %64 = arith.mulf %arg9, %63 : tensor<64x64xf32> + %65 = tt.load %arg11 : !tt.ptr> + %66 = arith.truncf %56 : tensor<64x64xf32> to tensor<64x64xf16> + %67 = tt.dot %66, %65, %64 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> + %68 = arith.mulf %51, %6 : tensor<64xf32> + %69 = arith.divf %68, %cst_2 : tensor<64xf32> + %70 = tt.advance %arg11, [%c64_i32, %c0_i32] : > + %71 = tt.advance %arg12, [%c64_i32, %c0_i32] : > + scf.yield %61, %67, %69, %70, %71 : tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, !tt.ptr>, !tt.ptr> + } {tt.divisibility_arg1 = dense<64> : tensor<1xi32>} + %33 = tt.advance %28, [%8, %c0_i32] : > + %34 = tt.advance %26, [%8, %c0_i32] : > + %35:5 = scf.for %arg7 = %8 to %10 step %c64_i32 iter_args(%arg8 = %32#0, %arg9 = %32#1, %arg10 = %32#2, %arg11 = %34, %arg12 = %33) -> (tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %46 = tt.load %arg12 : !tt.ptr> + %47 = tt.trans %46 {order = array} : tensor<64x64xf16> -> tensor<64x64xf16> + %48 = tt.dot %31, %47, %cst_3 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> + %49 = tt.splat %arg7 : i32 -> tensor<1x64xi32> + %50 = arith.addi %49, %12 : tensor<1x64xi32> + %51 = tt.broadcast %50 : tensor<1x64xi32> -> tensor<64x64xi32> + %52 = arith.cmpi sge, %13, %51 : tensor<64x64xi32> + %53 = arith.mulf %48, %14 : tensor<64x64xf32> + %54 = arith.select %52, %cst_3, %cst_1 : tensor<64x64xi1>, tensor<64x64xf32> + %55 = arith.addf %53, %54 : tensor<64x64xf32> + %56 = "tt.reduce"(%55) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %77 = arith.maxnumf %arg13, %arg14 : f32 + tt.reduce.return %77 : f32 + }) : (tensor<64x64xf32>) -> tensor<64xf32> + %57 = arith.maxnumf %arg10, %56 : tensor<64xf32> + %58 = tt.expand_dims %57 {axis = 1 : i32} : 
tensor<64xf32> -> tensor<64x1xf32> + %59 = tt.broadcast %58 : tensor<64x1xf32> -> tensor<64x64xf32> + %60 = arith.subf %55, %59 : tensor<64x64xf32> + %61 = math.exp2 %60 : tensor<64x64xf32> + %62 = "tt.reduce"(%61) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %77 = arith.addf %arg13, %arg14 : f32 + tt.reduce.return %77 : f32 + }) : (tensor<64x64xf32>) -> tensor<64xf32> + %63 = arith.subf %arg10, %57 : tensor<64xf32> + %64 = math.exp2 %63 : tensor<64xf32> + %65 = arith.mulf %arg8, %64 : tensor<64xf32> + %66 = arith.addf %65, %62 : tensor<64xf32> + %67 = tt.expand_dims %64 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> + %68 = tt.broadcast %67 : tensor<64x1xf32> -> tensor<64x64xf32> + %69 = arith.mulf %arg9, %68 : tensor<64x64xf32> + %70 = tt.load %arg11 : !tt.ptr> + %71 = arith.truncf %61 : tensor<64x64xf32> to tensor<64x64xf16> + %72 = tt.dot %71, %70, %69 : tensor<64x64xf16> * tensor<64x64xf16> -> tensor<64x64xf32> + %73 = arith.mulf %57, %15 : tensor<64xf32> + %74 = arith.divf %73, %cst_2 : tensor<64xf32> + %75 = tt.advance %arg11, [%c64_i32, %c0_i32] : > + %76 = tt.advance %arg12, [%c64_i32, %c0_i32] : > + scf.yield %66, %72, %74, %75, %76 : tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>, !tt.ptr>, !tt.ptr> + } {tt.divisibility_arg1 = dense<64> : tensor<1xi32>} + %36 = math.log2 %35#0 : tensor<64xf32> + %37 = arith.addf %35#2, %36 : tensor<64xf32> + %38 = tt.expand_dims %35#0 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> + %39 = tt.broadcast %38 : tensor<64x1xf32> -> tensor<64x64xf32> + %40 = arith.divf %35#1, %39 : tensor<64x64xf32> + %41 = arith.muli %arg6, %c2048_i32 : i32 + %42 = tt.addptr %arg3, %41 : !tt.ptr, i32 + %43 = tt.splat %42 : !tt.ptr -> tensor<64x!tt.ptr> + %44 = tt.addptr %43, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %44, %37 : tensor<64x!tt.ptr> + %45 = arith.truncf %40 : tensor<64x64xf32> to tensor<64x64xf16> + tt.store %30, %45 : !tt.ptr> + } + tt.return + } +} + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/loadstorecanonicalizer.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/loadstorecanonicalizer.mlir new file mode 100644 index 000000000..1ec3a65c9 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/loadstorecanonicalizer.mlir @@ -0,0 +1,24 @@ +// RUN: triton-adapter-opt --triton-to-linalg="named-ops=True" %s | FileCheck %s + +// CHECK-LABEL: func @loadstorecanonicalizer_simple( +tt.func public @loadstorecanonicalizer_simple(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 + %3 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %4 = tt.splat %3 : !tt.ptr -> tensor<1x!tt.ptr> + // CHECK: %[[CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] + // CHECK: memref.copy %[[CAST0]], %[[ALLOC0:.*]] + %5 = tt.load %4 : tensor<1x!tt.ptr> + %6 = tt.addptr %arg1, %0 : !tt.ptr, i32 + %7 = tt.splat %6 : !tt.ptr -> tensor<1x!tt.ptr> + // CHECK: %[[CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] + // CHECK: memref.copy %[[CAST1]], %[[ALLOC1:.*]] + %8 = tt.load %7 : tensor<1x!tt.ptr> + %9 = arith.addf %5, %8 : tensor<1xf32> + %10 = tt.addptr %arg2, %0 : !tt.ptr, i32 + %11 = tt.splat %10 : !tt.ptr -> tensor<1x!tt.ptr> + // CHECK: %[[CAST2:.*]] = memref.reinterpret_cast %[[ARG2:.*]] to offset + // CHECK: bufferization.materialize_in_destination %[[VAL:.*]] in writable %[[CAST2]] + tt.store %11, %9 : tensor<1x!tt.ptr> + tt.return +} + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir 
b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir new file mode 100644 index 000000000..af4b36b5d --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir @@ -0,0 +1,97 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func @kernel_low_mask( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %ldptr = tt.addptr %0, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %nans = arith.constant dense<0xFF80> : tensor<128xbf16> + %5 = tt.splat %arg2 : i32 -> tensor<128xi32> + %mask = arith.cmpi slt, %2, %5 : tensor<128xi32> + %buff = tt.load %ldptr, %mask, %nans : tensor<128x!tt.ptr> + tt.store %stptr, %buff, %mask : tensor<128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel_low_mask( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { + +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref to memref<128xbf16, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128], strides: [1] : memref to memref<128xbf16, strided<[1]>> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128xbf16> +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_12_0:.*]] = arith.maxsi %[[VAL_11]], %[[VAL_5]] : index +// CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_12_0]], %[[VAL_7]] : index +// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_12]], %[[VAL_7]] : index +// CHECK: scf.if %[[VAL_15]] { +// CHECK: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : memref<128xbf16>) +// CHECK: } +// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref> to memref> +// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128xbf16> + +// CHECK: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_16]][0] {{\[}}%[[VAL_12]]] [1] : tensor<128xbf16> to tensor +// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_9]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]] +// CHECK: return +// CHECK: } + +// ----- + +module { + tt.func @kernel_high_mask( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %ldptr = tt.addptr %0, %2 : 
tensor<128x!tt.ptr>, tensor<128xi32> + %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %nans = arith.constant dense<0xFF80> : tensor<128xbf16> + %5 = tt.splat %arg2 : i32 -> tensor<128xi32> + %mask = arith.cmpi sge, %2, %5 : tensor<128xi32> + %buff = tt.load %ldptr, %mask, %nans : tensor<128x!tt.ptr> + tt.store %stptr, %buff, %mask : tensor<128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel_high_mask( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, %[[PA_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[PA_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %[[PA_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[CONST_NAN:.*]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: %[[CONST_128:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant 0 : index +// CHECK: %[[CAST_IN:.*]] = memref.reinterpret_cast %[[PA_0]] to offset: [0], sizes: [128], strides: [1] : memref to memref<128xbf16, strided<[1]>> +// CHECK: %[[CAST_OUT:.*]] = memref.reinterpret_cast %[[PA_1]] to offset: [0], sizes: [128], strides: [1] : memref to memref<128xbf16, strided<[1]>> +// CHECK: %[[BUFF:.*]] = memref.alloc() : memref<128xbf16> +// CHECK: %[[VAL_0:.*]] = arith.index_cast %[[PA_2]] : i32 to index +// CHECK: %[[VAL_1_0:.*]] = arith.maxsi %[[VAL_0]], %[[CONST_0]] : index +// CHECK: %[[VAL_1:.*]] = arith.minsi %[[VAL_1_0]], %[[CONST_128]] : index +// CHECK: %[[VAL_2:.*]] = arith.subi %[[CONST_128]], %[[VAL_1]] : index +// CHECK: %[[VAL_3:.*]] = arith.cmpi slt, %[[VAL_2]], %[[CONST_128]] : index +// CHECK: scf.if %[[VAL_3]] { +// CHECK: linalg.fill ins(%[[CONST_NAN]] : bf16) outs(%[[BUFF]] : memref<128xbf16>) +// CHECK: } +// CHECK: %[[SUB_IN:.*]] = memref.subview %[[CAST_IN]]{{\[}}%[[VAL_1]]] {{\[}}%[[VAL_2]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: %[[SUB_BUFF:.*]] = memref.subview %[[BUFF]]{{\[}}%[[VAL_1]]] {{\[}}%[[VAL_2]]] [1] : memref<128xbf16> to memref> +// CHECK: memref.copy %[[SUB_IN]], %[[SUB_BUFF]] : memref> to memref> +// CHECK: %[[VAL_4:.*]] = bufferization.to_tensor %[[BUFF]] restrict writable : memref<128xbf16> +// CHECK: %[[SLICE_BUFF:.*]] = tensor.extract_slice %[[VAL_4]]{{\[}}%[[VAL_1]]] {{\[}}%[[VAL_2]]] [1] : tensor<128xbf16> to tensor +// CHECK: %[[SUB_OUT:.*]] = memref.subview %[[CAST_OUT]]{{\[}}%[[VAL_1]]] {{\[}}%[[VAL_2]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: bufferization.materialize_in_destination %[[SLICE_BUFF]] in writable %[[SUB_OUT]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_2d_high_mask.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_2d_high_mask.mlir new file mode 100644 index 000000000..8179b9212 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_2d_high_mask.mlir @@ -0,0 +1,144 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func @kernel_high_mask( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32, + %arg3 : i32, + %arg4 : i32, + %arg5 : i32 + ) + { + // Mimic a scenario where the raw pointer points to a buffer with dimension (1024, 1024) + // in row-major, but the actual tensor size is (arg2, arg3). + // We are trying to load a 128x256 sub-buffer starting at (2, 3). 
+ // The resulting memref: + // offset = 3074 + // size[1] = 128 + // size[0] = 256 + // stride[0] = 1024 + // stride[1] = 1 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x256x!tt.ptr> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x256x!tt.ptr> + // horizontal index + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c2 = arith.constant 2 : i32 + %c2tensor = tt.splat %c2 : i32 -> tensor<128xi32> + %offset2 = arith.addi %2, %c2tensor : tensor<128xi32> + %3 = tt.expand_dims %offset2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x256xi32> + // vertical index + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %c3 = arith.constant 3 : i32 + %c3tensor = tt.splat %c3 : i32 -> tensor<256xi32> + %offset5 = arith.addi %5, %c3tensor : tensor<256xi32> + %c1024 = arith.constant 1024 : i32 + %c1024tensor = tt.splat %c1024 : i32 -> tensor<256xi32> + %scale5 = arith.muli %offset5, %c1024tensor : tensor<256xi32> + %6 = tt.expand_dims %scale5 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %7 = tt.broadcast %6 : tensor<1x256xi32> -> tensor<128x256xi32> + // combined index + %index = arith.addi %4, %7 : tensor<128x256xi32> + %ldptr = tt.addptr %0, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %stptr = tt.addptr %1, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + // other value for masked load + %cnan = arith.constant 0xFF80 : bf16 + %nans = tt.splat %cnan : bf16 -> tensor<128x256xbf16> + // horizontal mask + %8 = tt.splat %arg2 : i32 -> tensor<128xi32> + %9 = arith.cmpi sge, %offset2, %8 : tensor<128xi32> + %10 = tt.splat %arg3 : i32 -> tensor<128xi32> + %11 = arith.cmpi slt, %offset2, %10 : tensor<128xi32> + %12 = arith.andi %9, %11 : tensor<128xi1> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> + %14 = tt.broadcast %13 : tensor<128x1xi1> -> tensor<128x256xi1> + // vertical mask + %15 = tt.splat %arg4 : i32 -> tensor<256xi32> + %16 = arith.cmpi sge, %offset5, %15 : tensor<256xi32> + %17 = tt.splat %arg5 : i32 -> tensor<256xi32> + %18 = arith.cmpi slt, %offset5, %17 : tensor<256xi32> + %19 = arith.andi %16, %18 : tensor<256xi1> + %20 = tt.expand_dims %19 {axis = 0 : i32} : tensor<256xi1> -> tensor<1x256xi1> + %21 = tt.broadcast %20 : tensor<1x256xi1> -> tensor<128x256xi1> + // combined mask + %mask = arith.andi %14, %21 : tensor<128x256xi1> + // offset0 = max(%arg2-2, 0), dim0 = min(%arg3-2, 128) - offset0 + // offset1 = max(%arg4-3, 0), dim1 = min(%arg5-3, 256) - offset1 + // TODO: need reinterpret cast + %buff = tt.load %ldptr, %mask, %nans : tensor<128x256x!tt.ptr> + tt.store %stptr, %buff, %mask : tensor<128x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel_high_mask( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32, %[[ARG_11:.*]]: i32, %[[ARG_12:.*]]: i32, %[[ARG_13:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 256 : index +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 259 : index +// CHECK-DAG: 
%[[VAL_9:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 130 : index +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index + + +// CHECK: %[[VAL_16:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<128x256xbf16> + +// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_20_0:.*]] = arith.maxsi %[[VAL_19]], %[[VAL_10]] : index +// CHECK: %[[VAL_20:.*]] = arith.minsi %[[VAL_20_0]], %[[VAL_14]] : index +// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_20]], %[[VAL_10]] : index +// CHECK: %[[VAL_22:.*]] = arith.subi %[[VAL_14]], %[[VAL_20]] : index + +// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_24_0:.*]] = arith.maxsi %[[VAL_23]], %[[VAL_10]] : index +// CHECK: %[[VAL_24:.*]] = arith.minsi %[[VAL_24_0]], %[[VAL_14]] : index +// CHECK: %[[VAL_25:.*]] = arith.subi %[[VAL_24]], %[[VAL_10]] : index + +// CHECK: %[[VAL_26:.*]] = arith.maxsi %[[VAL_21]], %[[VAL_6]] : index +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_21]], %[[VAL_22]] : index +// CHECK: %[[VAL_28:.*]] = arith.minsi %[[VAL_27]], %[[VAL_25]] : index + +// CHECK: %[[VAL_29:.*]] = arith.index_cast %[[VAL_4]] : i32 to index +// CHECK: %[[VAL_30_0:.*]] = arith.maxsi %[[VAL_29]], %[[VAL_9]] : index +// CHECK: %[[VAL_30:.*]] = arith.minsi %[[VAL_30_0]], %[[VAL_13]] : index +// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_9]] : index +// CHECK: %[[VAL_32:.*]] = arith.subi %[[VAL_13]], %[[VAL_30]] : index + +// CHECK: %[[VAL_33:.*]] = arith.index_cast %[[VAL_5]] : i32 to index +// CHECK: %[[VAL_34_0:.*]] = arith.maxsi %[[VAL_33]], %[[VAL_9]] : index +// CHECK: %[[VAL_34:.*]] = arith.minsi %[[VAL_34_0]], %[[VAL_13]] : index +// CHECK: %[[VAL_35:.*]] = arith.subi %[[VAL_34]], %[[VAL_9]] : index + +// CHECK: %[[VAL_36:.*]] = arith.maxsi %[[VAL_31]], %[[VAL_6]] : index +// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_31]], %[[VAL_32]] : index +// CHECK: %[[VAL_38:.*]] = arith.minsi %[[VAL_37]], %[[VAL_35]] : index + +// CHECK: %[[VAL_39:.*]] = arith.maxsi %[[VAL_26]], %[[VAL_6]] : index +// CHECK: %[[VAL_40:.*]] = arith.minsi %[[VAL_28]], %[[VAL_12]] : index +// CHECK: %[[VAL_41:.*]] = arith.subi %[[VAL_40]], %[[VAL_39]] : index + +// CHECK: %[[VAL_42:.*]] = arith.maxsi %[[VAL_36]], %[[VAL_6]] : index +// CHECK: %[[VAL_43:.*]] = arith.minsi %[[VAL_38]], %[[VAL_11]] : index +// CHECK: %[[VAL_44:.*]] = arith.subi %[[VAL_43]], %[[VAL_42]] : index + +// CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_41]], %[[VAL_12]] : index +// CHECK: %[[VAL_46:.*]] = arith.cmpi slt, %[[VAL_44]], %[[VAL_11]] : index +// CHECK: %[[VAL_47:.*]] = arith.ori %[[VAL_45]], %[[VAL_46]] : i1 +// CHECK: scf.if %[[VAL_47]] { +// CHECK: linalg.fill ins(%[[VAL_15]] : bf16) outs(%[[VAL_18]] : memref<128x256xbf16>) +// CHECK: } +// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_16]]{{\[}}%[[VAL_39]], %[[VAL_42]]] {{\[}}%[[VAL_41]], %[[VAL_44]]] [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: %[[VAL_49:.*]] = memref.subview %[[VAL_18]]{{\[}}%[[VAL_39]], %[[VAL_42]]] {{\[}}%[[VAL_41]], %[[VAL_44]]] [1, 1] : memref<128x256xbf16> to 
memref> +// CHECK: memref.copy %[[VAL_48]], %[[VAL_49]] : memref> to memref> +// CHECK: %[[VAL_50:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x256xbf16> +// CHECK: %[[VAL_77:.*]] = tensor.extract_slice %[[VAL_50]]{{\[}}%[[VAL_39]], %[[VAL_42]]] {{\[}}%[[VAL_41]], %[[VAL_44]]] [1, 1] : tensor<128x256xbf16> to tensor +// CHECK: %[[VAL_78:.*]] = memref.subview %[[VAL_17]]{{\[}}%[[VAL_39]], %[[VAL_42]]] {{\[}}%[[VAL_41]], %[[VAL_44]]] [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: bufferization.materialize_in_destination %[[VAL_77]] in writable %[[VAL_78]] : (tensor, memref>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_2d_low_mask.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_2d_low_mask.mlir new file mode 100644 index 000000000..2a43310d9 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_2d_low_mask.mlir @@ -0,0 +1,103 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func @kernel_low_mask( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32, + %arg3 : i32 + ) + { + // Mimic a scenario where the raw pointer points to a buffer with dimension (1024, 1024) + // in row-major, but the actual tensor size is (arg2, arg3). + // We are trying to load a 128x256 sub-buffer starting at (2, 3). + // The resulting memref: + // offset = 3074 + // size[1] = 128 + // size[0] = 256 + // stride[0] = 1024 + // stride[1] = 1 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x256x!tt.ptr> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x256x!tt.ptr> + // horizontal index + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c2 = arith.constant 2 : i32 + %c2tensor = tt.splat %c2 : i32 -> tensor<128xi32> + %offset2 = arith.addi %2, %c2tensor : tensor<128xi32> + %3 = tt.expand_dims %offset2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x256xi32> + // vertical index + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %c3 = arith.constant 3 : i32 + %c3tensor = tt.splat %c3 : i32 -> tensor<256xi32> + %offset5 = arith.addi %5, %c3tensor : tensor<256xi32> + %c1024 = arith.constant 1024 : i32 + %c1024tensor = tt.splat %c1024 : i32 -> tensor<256xi32> + %scale5 = arith.muli %offset5, %c1024tensor : tensor<256xi32> + %6 = tt.expand_dims %scale5 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %7 = tt.broadcast %6 : tensor<1x256xi32> -> tensor<128x256xi32> + // combined index + %index = arith.addi %4, %7 : tensor<128x256xi32> + %ldptr = tt.addptr %0, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %stptr = tt.addptr %1, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + // other value for masked load + %cnan = arith.constant 0xFF80 : bf16 + %nans = tt.splat %cnan : bf16 -> tensor<128x256xbf16> + // horizontal mask + %8 = tt.splat %arg2 : i32 -> tensor<128xi32> + %9 = arith.cmpi slt, %offset2, %8 : tensor<128xi32> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> + %11 = tt.broadcast %10 : tensor<128x1xi1> -> tensor<128x256xi1> + // vertical mask + %12 = tt.splat %arg3 : i32 -> tensor<256xi32> + %13 = arith.cmpi slt, %offset5, %12 : tensor<256xi32> + %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<256xi1> -> tensor<1x256xi1> + %15 = tt.broadcast %14 : tensor<1x256xi1> -> tensor<128x256xi1> + // combined mask 
+ %mask = arith.andi %11, %15 : tensor<128x256xi1> + // dim0 = min(%arg2-2, 128), dim1 = min(%arg3-3, 256) + // TODO: need reinterpret cast + %buff = tt.load %ldptr, %mask, %nans : tensor<128x256x!tt.ptr> + tt.store %stptr, %buff, %mask : tensor<128x256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel_low_mask( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 256 : index +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 259 : index +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 130 : index +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index + +// CHECK: %[[VAL_16:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [3074], sizes: [128, 256], strides: [1, 1024] : memref to memref<128x256xbf16, strided<[1, 1024], offset: 3074>> +// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<128x256xbf16> +// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_20_0:.*]] = arith.maxsi %[[VAL_19]], %[[VAL_10]] : index +// CHECK: %[[VAL_20:.*]] = arith.minsi %[[VAL_20_0]], %[[VAL_14]] : index +// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_20]], %[[VAL_10]] : index +// CHECK: %[[VAL_22:.*]] = arith.index_cast %[[VAL_3]] : i32 to index +// CHECK: %[[VAL_23_0:.*]] = arith.maxsi %[[VAL_22]], %[[VAL_9]] : index +// CHECK: %[[VAL_23:.*]] = arith.minsi %[[VAL_23_0]], %[[VAL_13]] : index +// CHECK: %[[VAL_24:.*]] = arith.subi %[[VAL_23]], %[[VAL_9]] : index +// CHECK: %[[VAL_25:.*]] = arith.minsi %[[VAL_21]], %[[VAL_12]] : index +// CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_24]], %[[VAL_11]] : index +// CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_12]] : index +// CHECK: %[[VAL_30:.*]] = arith.cmpi slt, %[[VAL_26]], %[[VAL_11]] : index +// CHECK: %[[VAL_31:.*]] = arith.ori %[[VAL_29]], %[[VAL_30]] : i1 +// CHECK: scf.if %[[VAL_31]] { +// CHECK: linalg.fill ins(%[[VAL_15]] : bf16) outs(%[[VAL_18]] : memref<128x256xbf16>) +// CHECK: } +// CHECK: %[[VAL_27:.*]] = memref.subview %[[VAL_16]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: %[[VAL_28:.*]] = memref.subview %[[VAL_18]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16> to memref> +// CHECK: memref.copy %[[VAL_27]], %[[VAL_28]] : memref> to memref> +// CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x256xbf16> +// CHECK: %[[VAL_41:.*]] = tensor.extract_slice %[[VAL_32]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : tensor<128x256xbf16> to tensor +// CHECK: %[[VAL_42:.*]] = memref.subview %[[VAL_17]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16, strided<[1, 1024], offset: 3074>> to memref> +// CHECK: bufferization.materialize_in_destination 
%[[VAL_41]] in writable %[[VAL_42]] +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir new file mode 100644 index 000000000..403292379 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir @@ -0,0 +1,51 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : i32 + ) + { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %ldptr = tt.addptr %0, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32> + %c7_i32 = arith.constant 7 : i32 + %splat_c7_i32 = tt.splat %c7_i32 : i32 -> tensor<128xi32> + %splat_c7_bf16 = arith.sitofp %splat_c7_i32 : tensor<128xi32> to tensor<128xbf16> + %5 = tt.splat %arg2 : i32 -> tensor<128xi32> + %mask = arith.cmpi slt, %2, %5 : tensor<128xi32> + %buff = tt.load %ldptr, %mask, %splat_c7_bf16 : tensor<128x!tt.ptr> + tt.store %stptr, %buff, %mask : tensor<128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 7.000000e+00 : bf16 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index + +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref to memref<128xbf16, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128], strides: [1] : memref to memref<128xbf16, strided<[1]>> +// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128xbf16> +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_2]] : i32 to index +// CHECK: %[[VAL_12_0:.*]] = arith.maxsi %[[VAL_11]], %[[VAL_5]] : index +// CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_12_0]], %[[VAL_7]] : index +// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_12]], %[[VAL_7]] : index +// CHECK: scf.if %[[VAL_15]] { +// CHECK: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : memref<128xbf16>) +// CHECK: } +// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref> +// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref> to memref> +// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128xbf16> + +// CHECK: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_16]][0] {{\[}}%[[VAL_12]]] [1] : tensor<128xbf16> to tensor +// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_9]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref> +// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]] : (tensor, memref>) -> () +// CHECK: return +// 
CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_splat_div.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_splat_div.mlir new file mode 100644 index 000000000..4005bbd31 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/masked_ldst_splat_div.mlir @@ -0,0 +1,108 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + // CHECK-LABEL: func.func @triton_splat_as_mask + tt.func public @triton_splat_as_mask(%arg0: !tt.ptr, %arg1: !tt.ptr) { + %c8_i32 = arith.constant 8 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> + %cst_0 = arith.constant dense<8> : tensor<1x8xi32> + %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %rmask = arith.cmpi slt, %2, %cst_0 : tensor<1x8xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1x8x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<1x8x!tt.ptr> + scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 : i32 { + %5 = arith.cmpi slt, %arg2, %c8_i32 : i32 + %6 = tt.addptr %3, %2 : tensor<1x8x!tt.ptr>, tensor<1x8xi32> + %xmask = tt.splat %5 : i1 -> tensor<1x8xi1> + %mask = arith.andi %rmask, %xmask : tensor<1x8xi1> + // CHECK-DAG: %[[input_slice:.*]] = memref.subview + // CHECK-DAG: %[[buffer_slice:.*]] = memref.subview + // CHECK: memref.copy %[[input_slice]], %[[buffer_slice]] + %7 = tt.load %6, %mask, %cst : tensor<1x8x!tt.ptr> + %8 = tt.addptr %4, %2 : tensor<1x8x!tt.ptr>, tensor<1x8xi32> + // CHECK-DAG: %[[buffer_tensor_slice:.*]] = tensor.extract_slice + // CHECK-DAG: %[[output_slice:.*]] = memref.subview + // CHECK: bufferization.materialize_in_destination %[[buffer_tensor_slice]] in writable %[[output_slice]] + tt.store %8, %7, %mask : tensor<1x8x!tt.ptr> + } + tt.return + } +} + +// ----- + +module { + // CHECK-LABEL: func.func @triton_divide_as_mask + tt.func public @triton_divide_as_mask(%arg0: !tt.ptr, + %arg1: !tt.ptr, + %arg2: i32) { + %c8_i32 = arith.constant 8 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> + %cst_0 = arith.constant dense<8> : tensor<1x8xi32> + %cst_1 = arith.constant dense<1> : tensor<1x8xi32> + %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %3 = tt.splat %arg2 : i32 -> tensor<1x8xi32> + %4 = arith.divsi %2, %3 : tensor<1x8xi32> + %rmask = arith.cmpi slt, %4, %cst_1 : tensor<1x8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1x8x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<1x8x!tt.ptr> + scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 : i32 { + %7 = arith.cmpi slt, %arg3, %c8_i32 : i32 + %8 = tt.addptr %5, %2 : tensor<1x8x!tt.ptr>, tensor<1x8xi32> + %xmask = tt.splat %7 : i1 -> tensor<1x8xi1> + %mask = arith.andi %rmask, %xmask : tensor<1x8xi1> + // CHECK-DAG: %[[input_slice:.*]] = memref.subview + // CHECK-DAG: %[[buffer_slice:.*]] = memref.subview + // CHECK: memref.copy %[[input_slice]], %[[buffer_slice]] + %9 = tt.load %8, %mask, %cst : tensor<1x8x!tt.ptr> + %10 = tt.addptr %6, %2 : tensor<1x8x!tt.ptr>, tensor<1x8xi32> + // CHECK-DAG: %[[buffer_tensor_slice:.*]] = tensor.extract_slice + // CHECK-DAG: %[[output_slice:.*]] = memref.subview + // CHECK: bufferization.materialize_in_destination %[[buffer_tensor_slice]] in writable 
%[[output_slice]] + tt.store %10, %9, %mask : tensor<1x8x!tt.ptr> + } + tt.return + } +} + +// ----- + +module { + // CHECK-LABEL: func.func @triton_bool_splat_condition_select + tt.func public @triton_bool_splat_condition_select(%arg0: !tt.ptr, + %arg1: !tt.ptr, + %arg2: !tt.ptr, + %arg3: i32) + attributes {noinline = false} { + // CHECK: %[[baseline:.*]] = arith.constant 0 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + // CHECK:%[[VAL_2:.*]] = arith.cmpi eq, [[ARG_3:%.+]], %[[baseline]] : i32 + // CHECK:%[[VAL_3:.*]] = tensor.empty() : tensor<32xi1> + // CHECK:%[[VAL_4:.*]] = linalg.fill ins([[VAL_2]] : i1) outs(%[[VAL_3]] : tensor<32xi1>) -> tensor<32xi1> + // CHECK:%5 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_4]], %0, %1 : tensor<32xi1>, tensor<32xf32>, tensor<32xf32>) outs(%0 : tensor<32xf32>) { + // CHECK:^bb0(%in: i1, %in_3: f32, %in_4: f32, %out: f32): + // CHECK:%[[VAL_6:.*]] = arith.select %in, %in_3, %in_4 : f32 + // CHECK:linalg.yield %[[VAL_6]] : f32 + // CHECK:} -> tensor<32xf32> + %7 = arith.cmpi eq, %arg3, %c0_i32 : i32 + %8 = tt.splat %7 : i1 -> tensor<32xi1> + %9 = arith.select %8, %3, %6 : tensor<32xi1>, tensor<32xf32> + %10 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %11 = tt.addptr %10, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + tt.store %11, %9 : tensor<32x!tt.ptr> + tt.return +} +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/max_value_index.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/max_value_index.mlir new file mode 100644 index 000000000..679608e9d --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/max_value_index.mlir @@ -0,0 +1,126 @@ +// RUN: triton-adapter-opt --triton-to-linalg --split-input-file %s | FileCheck %s + +module { + tt.func public @triton_per_fused_0d1d2d345678910111213(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: !tt.ptr , %arg4: !tt.ptr , %arg5: !tt.ptr , %arg6: !tt.ptr , %arg7: !tt.ptr , %arg8: !tt.ptr , %arg9: !tt.ptr , %arg10: !tt.ptr , %arg11: !tt.ptr , %arg12: i32 , %arg13: i32 ) attributes {noinline = false} { + %cst = arith.constant dense<0xFF800000> : tensor<1x4xf32> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<1x4xf32> + %true = arith.constant true + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %cst_1 = arith.constant dense<4> : tensor<1x4xi32> + %c0_i32 = arith.constant 0 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c4096_i32 : i32 + %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %3 = tt.expand_dims %2 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %4 = arith.cmpi slt, %3, %cst_1 : tensor<1x4xi32> + %5 = arith.select %4, %cst_0, %cst : tensor<1x4xi1>, tensor<1x4xf32> + %6 = tt.broadcast %5 : tensor<1x4xf32> -> tensor<4096x4xf32> + %7 = tt.broadcast %3 : tensor<1x4xi32> -> tensor<4096x4xi32> + %8:2 = "tt.reduce"(%6, %7) <{axis = 1 : i32}> ({ + ^bb0(%arg14: f32 , %arg15: i32 , %arg16: f32 , %arg17: i32 
): + %12 = arith.cmpf ogt, %arg14, %arg16 : f32 + %13 = arith.cmpf oeq, %arg14, %arg16 : f32 + %14 = arith.cmpf une, %arg14, %arg14 : f32 + %15 = arith.cmpf une, %arg16, %arg16 : f32 + %16 = arith.xori %15, %true : i1 + %17 = arith.andi %14, %16 : i1 + %18 = arith.ori %12, %17 : i1 + %19 = arith.andi %14, %15 : i1 + %20 = arith.ori %13, %19 : i1 + %21 = arith.cmpi slt, %arg15, %arg17 : i32 + %22 = arith.andi %20, %21 : i1 + %23 = arith.ori %18, %22 : i1 + %24 = arith.select %23, %arg14, %arg16 : f32 + %25 = arith.select %23, %arg15, %arg17 : i32 + tt.reduce.return %24, %25 : f32, i32 + }) : (tensor<4096x4xf32>, tensor<4096x4xi32>) -> (tensor<4096xf32>, tensor<4096xi32>) + %9 = tt.expand_dims %8#0 {axis = 1 : i32} : tensor<4096xf32> -> tensor<4096x1xf32> + %10 = tt.make_tensor_ptr %arg5, [%c4096_i64, %c4_i64], [%c4_i64, %c1_i64], [%1, %c0_i32] {order = array} : > + %11 = tt.broadcast %9 : tensor<4096x1xf32> -> tensor<4096x4xf32> + tt.store %10, %11 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + + +// CHECK: %[[VAL_89:.*]] = linalg.reduce ins(%[[VAL_6:.*]], %[[VAL_8:.*]] : tensor<4096x4xf32>, tensor<4096x4xi32>) outs(%[[VAL_9:.*]], %[[VAL_10:.*]] : tensor<4096xf32>, tensor<4096xi32>) dimensions = [1] +// CHECK: (%[[VAL_IN:.*]]: f32, %[[VAL_150:.*]]: i32, %[[VAL_INIT:.*]]: f32, %[[VAL_151:.*]]: i32) { +// CHECK: %[[VAL_264:.*]] = arith.cmpf ogt, %[[VAL_IN]], %[[VAL_INIT]] : f32 +// CHECK: %[[VAL_265:.*]] = arith.cmpf oeq, %[[VAL_IN]], %[[VAL_INIT]] : f32 +// CHECK: %[[VAL_266:.*]] = arith.cmpf une, %[[VAL_IN]], %[[VAL_IN]] : f32 +// CHECK: %[[VAL_267:.*]] = arith.cmpf une, %[[VAL_INIT]], %[[VAL_INIT]] : f32 +// CHECK: %[[VAL_268:.*]] = arith.xori %[[VAL_267]], %true : i1 +// CHECK: %[[VAL_269:.*]] = arith.andi %[[VAL_266]], %[[VAL_268]] : i1 +// CHECK: %[[VAL_270:.*]] = arith.ori %[[VAL_264]], %[[VAL_269]] : i1 +// CHECK: %[[VAL_271:.*]] = arith.andi %[[VAL_266]], %[[VAL_267]] : i1 +// CHECK: %[[VAL_272:.*]] = arith.ori %[[VAL_265]], %[[VAL_271]] : i1 +// CHECK: %[[VAL_273:.*]] = arith.cmpi slt, %[[VAL_150]], %[[VAL_151]] : i32 +// CHECK: %[[VAL_274:.*]] = arith.andi %[[VAL_272]], %[[VAL_273]] : i1 +// CHECK: %[[VAL_275:.*]] = arith.ori %[[VAL_270]], %[[VAL_274]] : i1 +// CHECK: %[[VAL_276:.*]] = arith.select %[[VAL_275]], %[[VAL_IN]], %[[VAL_INIT]] : f32 +// CHECK: %[[VAL_277:.*]] = arith.select %[[VAL_275]], %[[VAL_150]], %[[VAL_151]] : i32 +// CHECK: linalg.yield %[[VAL_276]], %[[VAL_277]] : f32, i32 +// CHECK: } + +// ----- + +module { + tt.func public @triton_test_fn_min_with_index_inner_scalar_0d1d2d345(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} , %arg3: !tt.ptr , %arg4: i32 , %arg5: i32 ) attributes {noinline = false} { + %true = arith.constant true + %c2_i32 = arith.constant 2 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c2_i32 : i32 + %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %3 = tt.splat %1 : i32 -> tensor<2xi32> + %4 = arith.addi %3, %2 : tensor<2xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<2x!tt.ptr>, tensor<2xi32> + %7 = tt.load %6 : tensor<2x!tt.ptr> + %8 = tt.splat %arg3 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %4 : tensor<2x!tt.ptr>, tensor<2xi32> + %10 = tt.load %9 : tensor<2x!tt.ptr> + %11:2 = "tt.reduce"(%7, %10) <{axis = 0 : i32}> ({ + ^bb0(%arg6: f16 , %arg7: i32 , %arg8: f16 , %arg9: i32 ): + %12 = arith.cmpf olt, %arg6, %arg8 : f16 + %13 = arith.cmpf oeq, %arg6, %arg8 
: f16 + %14 = arith.cmpf une, %arg6, %arg6 : f16 + %15 = arith.cmpf une, %arg8, %arg8 : f16 + %16 = arith.xori %15, %true : i1 + %17 = arith.andi %14, %16 : i1 + %18 = arith.ori %12, %17 : i1 + %19 = arith.andi %14, %15 : i1 + %20 = arith.ori %13, %19 : i1 + %21 = arith.cmpi slt, %arg7, %arg9 : i32 + %22 = arith.andi %20, %21 : i1 + %23 = arith.ori %18, %22 : i1 + %24 = arith.select %23, %arg6, %arg8 : f16 + %25 = arith.select %23, %arg7, %arg9 : i32 + tt.reduce.return %24, %25 : f16, i32 + }) : (tensor<2xf16>, tensor<2xi32>) -> (f16, i32) + tt.store %arg0, %11#0 : !tt.ptr + tt.store %arg1, %11#1 : !tt.ptr + tt.return + } +} + +// CHECK: %[[VAL_89:.*]] = linalg.reduce ins(%[[VAL_2:.*]], %[[VAL_4:.*]] : tensor<2xf16>, tensor<2xi32>) outs(%[[VAL_5:.*]], %[[VAL_6:.*]] : tensor, tensor) dimensions = [0] +// CHECK: (%[[VAL_IN:.*]]: f16, %[[VAL_150:.*]]: i32, %[[VAL_INIT:.*]]: f16, %[[VAL_151:.*]]: i32) { +// CHECK: %[[VAL_264:.*]] = arith.cmpf olt, %[[VAL_IN]], %[[VAL_INIT]] : f16 +// CHECK: %[[VAL_265:.*]] = arith.cmpf oeq, %[[VAL_IN]], %[[VAL_INIT]] : f16 +// CHECK: %[[VAL_266:.*]] = arith.cmpf une, %[[VAL_IN]], %[[VAL_IN]] : f16 +// CHECK: %[[VAL_267:.*]] = arith.cmpf une, %[[VAL_INIT]], %[[VAL_INIT]] : f16 +// CHECK: %[[VAL_268:.*]] = arith.xori %[[VAL_267]], %true : i1 +// CHECK: %[[VAL_269:.*]] = arith.andi %[[VAL_266]], %[[VAL_268]] : i1 +// CHECK: %[[VAL_270:.*]] = arith.ori %[[VAL_264]], %[[VAL_269]] : i1 +// CHECK: %[[VAL_271:.*]] = arith.andi %[[VAL_266]], %[[VAL_267]] : i1 +// CHECK: %[[VAL_272:.*]] = arith.ori %[[VAL_265]], %[[VAL_271]] : i1 +// CHECK: %[[VAL_273:.*]] = arith.cmpi slt, %[[VAL_150]], %[[VAL_151]] : i32 +// CHECK: %[[VAL_274:.*]] = arith.andi %[[VAL_272]], %[[VAL_273]] : i1 +// CHECK: %[[VAL_275:.*]] = arith.ori %[[VAL_270]], %[[VAL_274]] : i1 +// CHECK: %[[VAL_276:.*]] = arith.select %[[VAL_275]], %[[VAL_IN]], %[[VAL_INIT]] : f16 +// CHECK: %[[VAL_277:.*]] = arith.select %[[VAL_275]], %[[VAL_150]], %[[VAL_151]] : i32 +// CHECK: linalg.yield %[[VAL_276]], %[[VAL_277]] : f16, i32 +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/mului_extended.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/mului_extended.mlir new file mode 100644 index 000000000..c08bfb07e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/mului_extended.mlir @@ -0,0 +1,25 @@ +// RUN: triton-adapter-opt %s --triton-to-linalg | FileCheck %s + +module { + tt.func public @umulhi_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.load %5 : tensor<128x!tt.ptr> + %7 = tt.mulhiui %3, %6 : tensor<128xi32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %9, %7 : tensor<128x!tt.ptr> + tt.return + } +} + +//CHECK-LABEL: @umulhi_kernel +//CHECK: %[[VAL0:.*]] = bufferization.to_tensor %alloc restrict writable : memref<128xi32> +//CHECK: %[[VAL1:.*]] = bufferization.to_tensor %alloc_1 restrict writable : memref<128xi32> +//CHECK: %[[VAL2:.*]], %[[VAL3:.*]] = arith.mulsi_extended %in, %in_3 : i32 + diff 
--git a/third_party/ascend/test/Conversion/TritonToLinalg/offset_iter_arg.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/offset_iter_arg.mlir new file mode 100644 index 000000000..739845e5f --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/offset_iter_arg.mlir @@ -0,0 +1,77 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @test_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c3_i32 = arith.constant 3 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<128> : tensor<128xi32> + %cst_0 = arith.constant dense<0> : tensor<128xi32> + %cst_1 = arith.constant dense<300> : tensor<128xi32> + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %3 = tt.splat %1 : i32 -> tensor<128xi32> + %4 = arith.addi %3, %2 : tensor<128xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = scf.for %arg2 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg3 = %4) -> (tensor<128xi32>) : i32 { + %8 = scf.for %arg4 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg5 = %arg3) -> (tensor<128xi32>) : i32 { + %10 = arith.cmpi slt, %arg5, %cst_1 : tensor<128xi32> + %11 = tt.addptr %5, %arg5 : tensor<128x!tt.ptr>, tensor<128xi32> + %12 = tt.load %11, %10, %cst_0 : tensor<128x!tt.ptr> + %13 = tt.addptr %6, %arg5 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %13, %12, %10 : tensor<128x!tt.ptr> + %14 = arith.addi %arg5, %cst : tensor<128xi32> + scf.yield %14 : tensor<128xi32> + } + %9 = arith.addi %8, %cst : tensor<128xi32> + scf.yield %9 : tensor<128xi32> + } + tt.return + } +} + +// CHECK-LABEL: func.func @test_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref, +// CHECK-SAME: %[[VAL_2:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, +// CHECK-SAME: %[[VAL_3:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, +// CHECK-SAME: %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_10:.*]] = arith.constant 300 : index +// CHECK: %[[VAL_11:.*]] = arith.constant 128 : index +// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_13:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_15:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_7]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index +// CHECK: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_15]] to %[[VAL_13]] step %[[VAL_14]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (index) : i32 { +// CHECK: %[[VAL_22:.*]] = scf.for %[[VAL_23:.*]] = %[[VAL_15]] to %[[VAL_13]] step %[[VAL_14]] iter_args(%[[VAL_24:.*]] = %[[VAL_21]]) -> (index) : i32 { +// CHECK: %[[VAL_25:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_24]]], sizes: [128], strides: {{\[}}%[[VAL_12]]] : memref to memref<128xi32, strided<[?], offset: ?>> +// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<128xi32> +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_24]], %[[VAL_11]] 
: index +// CHECK: %[[VAL_28:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_10]] : index +// CHECK: %[[VAL_29:.*]] = arith.minsi %[[VAL_27]], %[[VAL_28]] : index +// CHECK: %[[VAL_30:.*]] = arith.subi %[[VAL_29]], %[[VAL_24]] : index +// CHECK: %[[VAL_31:.*]] = arith.cmpi slt, %[[VAL_30]], %[[VAL_11]] : index +// CHECK: scf.if %[[VAL_31]] { +// CHECK: linalg.fill ins(%[[VAL_15]] : i32) outs(%[[VAL_26]] : memref<128xi32>) +// CHECK: } +// CHECK: %[[VAL_32:.*]] = memref.subview %[[VAL_25]][0] {{\[}}%[[VAL_30]]] [1] : memref<128xi32, strided<[?], offset: ?>> to memref> +// CHECK: %[[VAL_33:.*]] = memref.subview %[[VAL_26]][0] {{\[}}%[[VAL_30]]] [1] : memref<128xi32> to memref> +// CHECK: memref.copy %[[VAL_32]], %[[VAL_33]] : memref> to memref> +// CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_26]] restrict writable : memref<128xi32> +// CHECK: %[[VAL_35:.*]] = memref.reinterpret_cast %[[VAL_3]] to offset: {{\[}}%[[VAL_24]]], sizes: [128], strides: {{\[}}%[[VAL_12]]] : memref to memref<128xi32, strided<[?], offset: ?>> +// CHECK: %[[VAL_36:.*]] = tensor.extract_slice %[[VAL_34]][0] {{\[}}%[[VAL_30]]] [1] : tensor<128xi32> to tensor +// CHECK: %[[VAL_37:.*]] = memref.subview %[[VAL_35]][0] {{\[}}%[[VAL_30]]] [1] : memref<128xi32, strided<[?], offset: ?>> to memref> +// CHECK: bufferization.materialize_in_destination %[[VAL_36]] in writable %[[VAL_37]] : (tensor, memref>) -> () +// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_24]], %[[VAL_11]] : index +// CHECK: scf.yield %[[VAL_38]] : index +// CHECK: } +// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_22]], %[[VAL_11]] : index +// CHECK: scf.yield %[[VAL_39]] : index +// CHECK: } +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/permute_3d.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/permute_3d.mlir new file mode 100644 index 000000000..0e824e6a5 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/permute_3d.mlir @@ -0,0 +1,26 @@ +// RUN: triton-adapter-opt --triton-to-annotation --triton-to-linalg %s | FileCheck %s +module { + // CHECK-LABEL: func @permute_3d + tt.func public @permute_3d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c12_i64 = arith.constant 12 : i64 + %c512_i64 = arith.constant 512 : i64 + %c512_i32 = arith.constant 512 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c512_i32: i32 + %2 = tt.make_tensor_ptr %arg1, [%c12_i64, %c512_i64], [%c512_i64, %c1_i64], [%c0_i32, %1] {order = array} : !tt.ptr> + %3 = tt.load %2 : !tt.ptr> + // CHECK-NOT: annotation.mark %[[LOADED:.*]] {MayImplicitTransposeWithLastAxis} : tensor<512x12xf16> + %4 = tt.reshape %3 : tensor<12x512xf16> -> tensor<12x4x128xf16> + // CHECK: %[[RES:.*]] = tensor.empty() : tensor<4x12x128xf16> + // CHECK: %[[TRANS:.*]] = linalg.transpose ins(%[[SRC:.*]] : tensor<12x4x128xf16>) outs(%[[RES]] : tensor<4x12x128xf16>) permutation = [1, 0, 2] + %5 = tt.trans %4 {order = array} : tensor<12x4x128xf16> -> tensor<4x12x128xf16> + %6 = tt.reshape %5 : tensor<4x12x128xf16> -> tensor<6144xf16> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<6144x!tt.ptr> + %8 = tt.make_range {end = 6144 : i32, start = 0 : i32} : tensor<6144xi32> + %9 = tt.addptr %7, %8 : tensor<6144x!tt.ptr>, tensor<6144xi32> + tt.store %9, %6 evictionPolicy = evict_last : tensor<6144x!tt.ptr> + tt.return + } +} diff --git 
a/third_party/ascend/test/Conversion/TritonToLinalg/precise_div.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/precise_div.mlir new file mode 100644 index 000000000..493917f3c --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/precise_div.mlir @@ -0,0 +1,37 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @triton_divRn(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c32_i32 = arith.constant 32 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c2048_i32 = arith.constant 2048 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c2048_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr> + scf.for %arg3 = %c0_i32 to %c32_i32 step %c1_i32 : i32 { + %6 = arith.muli %arg3, %c64_i32 : i32 + %7 = arith.addi %1, %6 : i32 + %8 = tt.splat %7 : i32 -> tensor<64xi32> + %9 = arith.addi %8, %2 : tensor<64xi32> + %10 = tt.addptr %3, %9 : tensor<64x!tt.ptr>, tensor<64xi32> + %11 = tt.load %10 : tensor<64x!tt.ptr> + %12 = tt.addptr %4, %9 : tensor<64x!tt.ptr>, tensor<64xi32> + %13 = tt.load %12 : tensor<64x!tt.ptr> + %14 = tt.precise_divf %11, %13 : tensor<64xf32> + %15 = tt.addptr %5, %9 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %15, %14 : tensor<64x!tt.ptr> + } + tt.return + } +} + +//CHECK: %[[VAL0:.*]] = bufferization.to_tensor %alloc restrict writable : memref<64xf32> +//CHECK: %[[VAL1:.*]] = bufferization.to_tensor %alloc_1 restrict writable : memref<64xf32> +//CHECK: %[[VAL2:.*]] = arith.divf %in, %in_3 : f32 + + diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/precise_sqrt.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/precise_sqrt.mlir new file mode 100644 index 000000000..ccf4975d9 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/precise_sqrt.mlir @@ -0,0 +1,29 @@ +// RUN: triton-adapter-opt %s --triton-to-linalg | FileCheck %s + +module { + tt.func public @sqrtrn_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c512_i32 = arith.constant 512 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c512_i32 : i32 + %2 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %3 = tt.splat %1 : i32 -> tensor<512xi32> + %4 = arith.addi %3, %2 : tensor<512xi32> + %5 = tt.splat %arg3 : i32 -> tensor<512xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<512xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<512x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + %9 = tt.load %8, %6 : tensor<512x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<512x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + %12 = tt.load %11, %6 : tensor<512x!tt.ptr> + %13 = tt.precise_sqrt %12 : tensor<512xf32> + %14 = arith.addf %9, %13 : tensor<512xf32> + %15 = tt.splat %arg2 : !tt.ptr -> tensor<512x!tt.ptr> + %16 = tt.addptr %15, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + tt.store %16, %14, %6 : tensor<512x!tt.ptr> + tt.return + } +} + + +//CHECK: %[[OUTPUT:.*]] = math.sqrt %[[INPUT:.*]] : f32 diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reduce_with_index_attr.mlir 
b/third_party/ascend/test/Conversion/TritonToLinalg/reduce_with_index_attr.mlir new file mode 100644 index 000000000..6597bc182 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reduce_with_index_attr.mlir @@ -0,0 +1,29 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @argmax_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> + %3 = tt.splat %1 : i32 -> tensor<4096xi32> + %4 = arith.addi %3, %2 : tensor<4096xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> + %7 = tt.load %6 : tensor<4096x!tt.ptr> + // CHECK: %[[REDUCED:.*]]:2 = linalg.reduce ins(%[[INPUT1:.*]], %[[INPUT2:.*]] : tensor<4096xf32>, tensor<4096xi32>) outs(%[[OUTPUT1:.*]], %[[OUTPUT2:.*]] : tensor, tensor) dimensions = [0] {reduce_mode = "max_with_index"} + %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ + ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): + %11 = arith.cmpf ogt, %arg9, %arg11 : f32 + %12 = arith.cmpf oeq, %arg9, %arg11 : f32 + %13 = arith.cmpi slt, %arg10, %arg12 : i32 + %14 = arith.andi %12, %13 : i1 + %15 = arith.ori %11, %14 : i1 + %16 = arith.select %15, %arg9, %arg11 : f32 + %17 = arith.select %15, %arg10, %arg12 : i32 + tt.reduce.return %16, %17 : f32, i32 + }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) + %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 + tt.store %9, %8#1 : !tt.ptr + tt.return + } +} \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reduceany_bool.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reduceany_bool.mlir new file mode 100644 index 000000000..5a40ee7ae --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reduceany_bool.mlir @@ -0,0 +1,84 @@ +// RUN: triton-adapter-opt -split-input-file --triton-to-linalg %s | FileCheck %s + +// CHECK-LABEL: func.func @kernel( +module { + tt.func public @kernel(%input : !tt.ptr, %output : !tt.ptr) + { + %cst = arith.constant dense<0> : tensor<128xi8> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %input : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.bitcast %2 : tensor<128x!tt.ptr> -> tensor<128x!tt.ptr> + %in = tt.load %3 : tensor<128x!tt.ptr> + %4 = arith.cmpi ne, %in, %cst : tensor<128xi8> + // CHECK: linalg.reduce + // CHECK: arith.ori + %5 = "tt.reduce"(%4) <{axis = 0 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %6 = arith.ori %arg0, %arg1 : i1 + tt.reduce.return %6 : i1 + }) : (tensor<128xi1>) -> i1 + %7 = tt.bitcast %output : !tt.ptr -> !tt.ptr + %8 = tt.splat %7 : !tt.ptr -> tensor<1x!tt.ptr> + %9 = tt.splat %5 : i1 -> tensor<1xi1> + %10 = arith.extui %9 : tensor<1xi1> to tensor<1xi8> + tt.store %8, %10 : tensor<1x!tt.ptr> + tt.return + } +} + +// ----- +// CHECK-LABEL: func.func @kernel_reduceAnd( +module { + tt.func public @kernel_reduceAnd(%input : !tt.ptr, %output : !tt.ptr) + { + %cst = arith.constant dense<0> : tensor<128xi8> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %input : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.bitcast %2 : tensor<128x!tt.ptr> -> tensor<128x!tt.ptr> + %in = tt.load %3 : tensor<128x!tt.ptr> + %4 = arith.cmpi ne, %in, %cst : tensor<128xi8> + // 
CHECK: linalg.reduce + // CHECK: arith.andi + %5 = "tt.reduce"(%4) <{axis = 0 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %6 = arith.andi %arg0, %arg1 : i1 + tt.reduce.return %6 : i1 + }) : (tensor<128xi1>) -> i1 + %7 = tt.bitcast %output : !tt.ptr -> !tt.ptr + %8 = tt.splat %7 : !tt.ptr -> tensor<1x!tt.ptr> + %9 = tt.splat %5 : i1 -> tensor<1xi1> + %10 = arith.extui %9 : tensor<1xi1> to tensor<1xi8> + tt.store %8, %10 : tensor<1x!tt.ptr> + tt.return + } +} + +// ----- +// CHECK-LABEL: func.func @kernel_reduceAnd( +module { + tt.func public @kernel_reduceAnd(%input : !tt.ptr, %output : !tt.ptr) + { + %cst = arith.constant dense<0> : tensor<128xi8> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %input : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.bitcast %2 : tensor<128x!tt.ptr> -> tensor<128x!tt.ptr> + %in = tt.load %3 : tensor<128x!tt.ptr> + %4 = arith.cmpi ne, %in, %cst : tensor<128xi8> + // CHECK: linalg.reduce + // CHECK: arith.xori + %5 = "tt.reduce"(%4) <{axis = 0 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %6 = arith.xori %arg0, %arg1 : i1 + tt.reduce.return %6 : i1 + }) : (tensor<128xi1>) -> i1 + %7 = tt.bitcast %output : !tt.ptr -> !tt.ptr + %8 = tt.splat %7 : !tt.ptr -> tensor<1x!tt.ptr> + %9 = tt.splat %5 : i1 -> tensor<1xi1> + %10 = arith.extui %9 : tensor<1xi1> to tensor<1xi8> + tt.store %8, %10 : tensor<1x!tt.ptr> + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir new file mode 100644 index 000000000..525839f8d --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir @@ -0,0 +1,60 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, + %res : tensor<256x16x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %c256 = arith.constant 256 : i32 + %ct256 = tt.splat %c256 : i32 -> tensor<32xi32> + %ws = arith.muli %ct256, %0 : tensor<32xi32> + %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %m2 = tt.broadcast %1 : tensor<32x1xi32> -> tensor<32x256xi32> + %100 = tt.expand_dims %m2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> + %moff = tt.broadcast %100 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> + %33 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %34 = tt.expand_dims %33 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %k2 = tt.broadcast %34 : tensor<1x256xi32> -> tensor<32x256xi32> + %200 = tt.expand_dims %k2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> + %koff = tt.broadcast %200 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> + %23 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %n2 = tt.broadcast %24 : tensor<1x16xi32> -> tensor<256x16xi32> + %300 = tt.expand_dims %n2 {axis = 0 : i32} : tensor<256x16xi32> -> tensor<1x256x16xi32> + %noff = tt.broadcast %300 : tensor<1x256x16xi32> -> tensor<32x256x16xi32> + %mkoff = arith.addi %moff, %koff : tensor<32x256x16xi32> + %mknoff = arith.addi %mkoff, %noff : tensor<32x256x16xi32> + // afloat pointer + %8 = tt.splat %afloat : !tt.ptr -> tensor<32x256x16x!tt.ptr> + %9 = tt.addptr %8, %mknoff : tensor<32x256x16x!tt.ptr>, 
tensor<32x256x16xi32> + %afm = tt.load %9 : tensor<32x256x16x!tt.ptr> + %6 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: bf16, %arg6: bf16): + %21 = arith.cmpf ogt, %arg5, %arg6 : bf16 + %22 = arith.select %21, %arg5, %arg6 : bf16 + tt.reduce.return %22 : bf16 + }) {axis = 0 : i32} : (tensor<32x256x16xbf16>) -> tensor<256x16xbf16> + tt.store %res, %6 : tensor<256x16x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref<256x16xbf16>, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { + +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [32, 256, 16], strides: [256, 1, 1] : memref to memref<32x256x16xbf16, strided<[256, 1, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<32x256x16xbf16> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<32x256x16xbf16, strided<[256, 1, 1]>> to memref<32x256x16xbf16> +// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<32x256x16xbf16> +// CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<256x16xbf16> + +// CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<32x256x16xbf16>) outs(%[[VAL_10]] : tensor<256x16xbf16>) dimensions = [0] {reduce_mode = "max_with_index"} +// CHECK: (%[[VAL_13:.*]]: bf16, %[[VAL_14:.*]]: bf16) { +// CHECK: %[[VAL_15:.*]] = arith.cmpf ogt, %[[VAL_13]], %[[VAL_14]] : bf16 +// CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_13]], %[[VAL_14]] : bf16 +// CHECK: linalg.yield %[[VAL_16]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_1]] : (tensor<256x16xbf16>, memref<256x16xbf16>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir new file mode 100644 index 000000000..fddf08121 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir @@ -0,0 +1,53 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, + %res : !tt.ptr + ) -> () { + // offset calculations + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %c256 = arith.constant 256 : i32 + %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> + %ws = arith.muli %ct256, %0 : tensor<512xi32> + %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> + %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> + %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> + // afloat pointer + %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> + %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> + // res pointer + %18 = tt.splat %res : !tt.ptr -> tensor<256x!tt.ptr> + %19 = tt.addptr %18, %3 : tensor<256x!tt.ptr>, tensor<256xi32> + %afm = tt.load %9 : tensor<512x256x!tt.ptr> + %5 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: bf16, %arg6: bf16): + %21 = arith.addf %arg5, %arg6 : bf16 + 
tt.reduce.return %21 : bf16 + }) {axis = 0 : i32} : (tensor<512x256xbf16>) -> tensor<256xbf16> + tt.store %19, %5 : tensor<256x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { + +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref to memref<512x256xbf16, strided<[256, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref to memref<256xbf16, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[256, 1]>> to memref<512x256xbf16> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> +// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xbf16> +// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_11]] : tensor<256xbf16>) -> tensor<256xbf16> +// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_12]] : tensor<256xbf16>) dimensions = [0] +// CHECK: (%[[VAL_14:.*]]: bf16, %[[VAL_15:.*]]: bf16) { +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : bf16 +// CHECK: linalg.yield %[[VAL_16]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir new file mode 100644 index 000000000..39629df06 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir @@ -0,0 +1,52 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, + %res : !tt.ptr + ) -> () { + // offset calculations + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %c256 = arith.constant 256 : i32 + %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> + %ws = arith.muli %ct256, %0 : tensor<512xi32> + %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> + %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> + %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> + // afloat pointer + %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> + %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> + // res pointer + %18 = tt.splat %res : !tt.ptr -> tensor<512x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<512x!tt.ptr>, tensor<512xi32> + %afm = tt.load %9 : tensor<512x256x!tt.ptr> + %5 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: bf16, %arg6: bf16): + %21 = arith.addf %arg5, %arg6 : bf16 + tt.reduce.return %21 : bf16 + }) {axis = 1 : i32} : (tensor<512x256xbf16>) -> tensor<512xbf16> + 
tt.store %19, %5 : tensor<512x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref to memref<512x256xbf16, strided<[256, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref to memref<512xbf16, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[256, 1]>> to memref<512x256xbf16> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> +// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<512xbf16> +// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_11]] : tensor<512xbf16>) -> tensor<512xbf16> +// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_14]] : tensor<512xbf16>) dimensions = [1] +// CHECK: (%[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16) { +// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : bf16 +// CHECK: linalg.yield %[[VAL_18]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] : (tensor<512xbf16>, memref<512xbf16, strided<[1]>>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir new file mode 100644 index 000000000..40d91ae8e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir @@ -0,0 +1,52 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, + %res : !tt.ptr + ) -> () { + // offset calculations + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %c256 = arith.constant 256 : i32 + %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> + %ws = arith.muli %ct256, %0 : tensor<512xi32> + %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> + %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> + %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> + // afloat pointer + %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> + %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> + // res pointer + %18 = tt.splat %res : !tt.ptr -> tensor<256x!tt.ptr> + %19 = tt.addptr %18, %3 : tensor<256x!tt.ptr>, tensor<256xi32> + %afm = tt.load %9 : tensor<512x256x!tt.ptr> + %5 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %21 : f32 + }) {axis = 0 : i32} : (tensor<512x256xf32>) -> tensor<256xf32> + tt.store %19, %5 : tensor<256x!tt.ptr> + tt.return + } 
+} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref to memref<512x256xf32, strided<[256, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[256, 1]>> to memref<512x256xf32> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> +// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xf32> +// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_12]] : tensor<256xf32>) dimensions = [0] +// CHECK: (%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32) { +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: linalg.yield %[[VAL_16]] : f32 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir new file mode 100644 index 000000000..39ab4ce39 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir @@ -0,0 +1,53 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, + %res : !tt.ptr + ) -> () { + // offset calculations + %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %c256 = arith.constant 256 : i32 + %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> + %ws = arith.muli %ct256, %0 : tensor<512xi32> + %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> + %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> + %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> + // afloat pointer + %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> + %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> + // res pointer + %18 = tt.splat %res : !tt.ptr -> tensor<512x!tt.ptr> + %19 = tt.addptr %18, %0 : tensor<512x!tt.ptr>, tensor<512xi32> + %afm = tt.load %9 : tensor<512x256x!tt.ptr> + %5 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %21 : f32 + }) {axis = 1 : i32} : (tensor<512x256xf32>) -> tensor<512xf32> + tt.store %19, %5 : tensor<512x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: 
%[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: [256, 1] : memref to memref<512x256xf32, strided<[256, 1]>> +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref to memref<512xf32, strided<[1]>> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> +// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[256, 1]>> to memref<512x256xf32> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> + +// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xf32> +// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_13]] : tensor<512xf32>) -> tensor<512xf32> +// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_14]] : tensor<512xf32>) dimensions = [1] +// CHECK: (%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32) { +// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32 +// CHECK: linalg.yield %[[VAL_18]] : f32 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir new file mode 100644 index 000000000..bf315feae --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir @@ -0,0 +1,59 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, + %res : !tt.ptr, + %out: tensor<32x16x!tt.ptr> + ) -> () { + // offset calculations + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %c256 = arith.constant 256 : i32 + %ct256 = tt.splat %c256 : i32 -> tensor<32xi32> + %ws = arith.muli %ct256, %0 : tensor<32xi32> + %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> + %m2 = tt.broadcast %1 : tensor<32x1xi32> -> tensor<32x256xi32> + %100 = tt.expand_dims %m2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> + %moff = tt.broadcast %100 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> + %33 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %34 = tt.expand_dims %33 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %k2 = tt.broadcast %34 : tensor<1x256xi32> -> tensor<32x256xi32> + %200 = tt.expand_dims %k2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> + %koff = tt.broadcast %200 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> + %23 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> + %n2 = tt.broadcast %24 : tensor<1x16xi32> -> tensor<256x16xi32> + %300 = tt.expand_dims %n2 {axis = 0 : i32} : tensor<256x16xi32> -> tensor<1x256x16xi32> + %noff = tt.broadcast %300 : tensor<1x256x16xi32> -> tensor<32x256x16xi32> + %mkoff = arith.addi %moff, %koff : tensor<32x256x16xi32> + %mknoff = arith.addi %mkoff, %noff : tensor<32x256x16xi32> + // afloat pointer + %8 = tt.splat %afloat : 
!tt.ptr -> tensor<32x256x16x!tt.ptr> + %9 = tt.addptr %8, %mknoff : tensor<32x256x16x!tt.ptr>, tensor<32x256x16xi32> + %afm = tt.load %9 : tensor<32x256x16x!tt.ptr> + %5 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: bf16, %arg6: bf16): + %21 = arith.addf %arg5, %arg6 : bf16 + tt.reduce.return %21 : bf16 + }) {axis = 1 : i32} : (tensor<32x256x16xbf16>) -> tensor<32x16xbf16> + tt.store %out, %5 : tensor<32x16x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref, %[[VAL_2:.*]]: memref<32x16xbf16>, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { + +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [32, 256, 16], strides: [256, 1, 1] : memref to memref<32x256x16xbf16, strided<[256, 1, 1]>> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<32x256x16xbf16> +// CHECK: memref.copy %[[VAL_8]], %[[VAL_9]] : memref<32x256x16xbf16, strided<[256, 1, 1]>> to memref<32x256x16xbf16> +// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<32x256x16xbf16> +// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<32x16xbf16> +// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_7]] : bf16) outs(%[[VAL_11]] : tensor<32x16xbf16>) -> tensor<32x16xbf16> +// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<32x256x16xbf16>) outs(%[[VAL_12]] : tensor<32x16xbf16>) dimensions = [1] +// CHECK: (%[[VAL_14:.*]]: bf16, %[[VAL_15:.*]]: bf16) { +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : bf16 +// CHECK: linalg.yield %[[VAL_16]] : bf16 +// CHECK: } +// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_2]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_scalar.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_scalar.mlir new file mode 100644 index 000000000..ed0ed2b03 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_scalar.mlir @@ -0,0 +1,40 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr) + { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %afloat : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %afm = tt.load %2 : tensor<128x!tt.ptr> + %3 = "tt.reduce"(%afm) ({ + ^bb0(%arg5: bf16, %arg6: bf16): + %21 = arith.addf %arg5, %arg6 : bf16 + tt.reduce.return %21 : bf16 + }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16 + tt.store %res, %3 : !tt.ptr + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref {tt.tensor_kind = 0 : i32}, %[[VAL_1:.*]]: memref {tt.tensor_kind = 1 : i32}, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref 
to memref<128xbf16, strided<[1]>> +// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<128xbf16> +// CHECK: memref.copy %[[VAL_6]], %[[VAL_7]] : memref<128xbf16, strided<[1]>> to memref<128xbf16> +// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_7]] restrict writable : memref<128xbf16> +// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor +// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_5]] : bf16) outs(%[[VAL_9]] : tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<128xbf16>) outs(%[[VAL_10]] : tensor) dimensions = [0] +// CHECK: (%[[VAL_12:.*]]: bf16, %[[VAL_13:.*]]: bf16) { +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : bf16 +// CHECK: linalg.yield %[[VAL_15]] : bf16 +// CHECK: } +// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_11]][] : tensor +// CHECK: %[[VAL_18:.*]] = tensor.empty() : tensor<1xbf16> +// CHECK: %[[VAL_19:.*]] = linalg.fill ins(%[[VAL_16]] : bf16) outs(%[[VAL_18]] : tensor<1xbf16>) -> tensor<1xbf16> +// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xbf16, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]] : (tensor<1xbf16>, memref<1xbf16, strided<[1]>>) -> () +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_unit_dim.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_unit_dim.mlir new file mode 100644 index 000000000..bf7aaf452 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reducesum_unit_dim.mlir @@ -0,0 +1,67 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s --split-input-file | FileCheck %s -check-prefixes=CHECK +module { + // CHECK-LABEL: @triton_addptr_f32 + tt.func public @triton_addptr_f32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32> + %cst_1 = arith.constant dense<256> : tensor<256xi32> + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %1 = arith.cmpi slt, %0, %cst_1 : tensor<256xi32> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %3 = tt.addptr %2, %0 : tensor<256x!tt.ptr>, tensor<256xi32> + %4 = tt.load %3, %1, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr> + %5 = arith.select %1, %4, %cst_0 : tensor<256xi1>, tensor<256xf32> + %6 = "tt.reduce"(%5) <{axis = 0 : i32}> ({ + ^bb0(%arg3: f32, %arg4: f32): + %11 = arith.addf %arg3, %arg4 : f32 + tt.reduce.return %11 : f32 + }) : (tensor<256xf32>) -> f32 + %7 = arith.addf %6, %cst : f32 + //CHECK: %[[TENSOR:.*]] = tensor.empty() : tensor<1xf32> + //CHECK: %[[RES:.*]] = linalg.fill ins(%[[VAL:.*]] : f32) outs(%[[TENSOR]] : tensor<1xf32>) -> tensor<1xf32> + %8 = tt.splat %7 : f32 -> tensor<1xf32> + //CHECK: %[[OUTPTR:.*]] = memref.reinterpret_cast %[[ARG_3:.*]] to offset: [0], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1]>> + %9 = tt.addptr %arg1, %c0_i32 : !tt.ptr, i32 + %10 = tt.splat %9 : !tt.ptr -> tensor<1x!tt.ptr> + //CHECK: bufferization.materialize_in_destination %[[VAL_1:.*]] in writable %[[OUTPTR]] : (tensor<1xf32>, memref<1xf32, strided<[1]>>) -> () + tt.store %10, %8 {cache = 1 : i32, evict = 1 : i32} : 
tensor<1x!tt.ptr> + tt.return + } +} + +// ----- + +module { + // CHECK-LABEL: @triton_addptr_1x1xf32 + tt.func public @triton_addptr_1x1xf32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32> + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<1x256x!tt.ptr> + %3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c256_i32 iter_args(%arg4 = %cst) -> (tensor<1x256xf32>) : i32 { + %8 = tt.splat %arg3 : i32 -> tensor<1x256xi32> + %9 = arith.addi %8, %1 : tensor<1x256xi32> + %10 = tt.addptr %2, %9 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> + %11 = tt.load %10 evictionPolicy = evict_first : tensor<1x256x!tt.ptr> + %12 = arith.addf %arg4, %11 : tensor<1x256xf32> + scf.yield %12 : tensor<1x256xf32> + } + %4 = "tt.reduce"(%3) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: f32): + %8 = arith.addf %arg3, %arg4 : f32 + tt.reduce.return %8 : f32 + }) : (tensor<1x256xf32>) -> tensor<1xf32> + //CHECK: %[[RES:.*]] = tensor.expand_shape %[[VAL:.*]] {{\[\[}}0, 1]] output_shape {{\[}}1, 1] : tensor<1xf32> into tensor<1x1xf32> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<1xf32> -> tensor<1x1xf32> + //CHECK: %[[OUTPTR:.*]] = memref.reinterpret_cast %[[ARG_1:.*]] to offset: [0], sizes: [1, 1], strides: [1, 1] : memref to memref<1x1xf32, strided<[1, 1]>> + %6 = tt.addptr %arg1, %c0_i32 : !tt.ptr, i32 + %7 = tt.splat %6 : !tt.ptr -> tensor<1x1x!tt.ptr> + //CHECK: bufferization.materialize_in_destination %[[RES]] in writable %[[OUTPTR]] : (tensor<1x1xf32>, memref<1x1xf32, strided<[1, 1]>>) -> () + tt.store %7, %5 : tensor<1x1x!tt.ptr> + tt.return + } +} \ No newline at end of file diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/reshape.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/reshape.mlir new file mode 100644 index 000000000..1679b1416 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/reshape.mlir @@ -0,0 +1,25 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @auto_gen_kernel_01(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<512> : tensor<1xi64> +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[4, 4, 32]> : tensor<3xi64> + %c512_i32 = arith.constant 512 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c512_i32 : i32 + %2 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %3 = tt.splat %1 : i32 -> tensor<512xi32> + %4 = arith.addi %3, %2 : tensor<512xi32> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<512x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + %7 = tt.load %6 evictionPolicy = evict_last : tensor<512x!tt.ptr> +// CHECK: %[[VAL_2:.*]] = tensor.reshape %[[VAL_0:.*]](%[[VAL_1:.*]]) : (tensor<512xf64>, tensor<3xi64>) -> tensor<4x4x32xf64> + %8 = tt.reshape %7 : tensor<512xf64> -> tensor<4x4x32xf64> + %9 = arith.addf %8, %8 : tensor<4x4x32xf64> +// CHECK: %[[VAL_7:.*]] = tensor.reshape %[[VAL_3:.*]](%[[VAL_4]]) : (tensor<4x4x32xf64>, tensor<1xi64>) -> tensor<512xf64> + %10 = tt.reshape %9 : tensor<4x4x32xf64> -> tensor<512xf64> + %11 = 
tt.splat %arg0 : !tt.ptr -> tensor<512x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + tt.store %12, %10 evictionPolicy = evict_last : tensor<512x!tt.ptr> + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/scalar_cl_extern_elementwise.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/scalar_cl_extern_elementwise.mlir new file mode 100644 index 000000000..56e754eb2 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/scalar_cl_extern_elementwise.mlir @@ -0,0 +1,20 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @triton_unk_fused_pow_0(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %0 = tt.addptr %arg0, %c0_i32 : !tt.ptr, i32 + %1 = tt.load %0 : !tt.ptr + %2 = arith.extf %1 : f16 to f32 + // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INPUT:.*]] : tensor<1xf32> + // CHECK: %[[MAPPED:.*]] = linalg.map { func.call {callee = @__hmf_sqrtf} } ins(%[[TENSOR]] : tensor<1xf32>) outs(%[[TENSOR]] : tensor<1xf32>) + // CHECK: %[[VAR_EXTRACETED:.+]] = tensor.extract %[[MAPPED]][%[[C0:.+]]] : tensor<1xf32> + %3 = tt.extern_elementwise %2 {libname = "", libpath = "", pure = true, symbol = "__hmf_sqrtf"} : (f32) -> f32 + %4 = tt.addptr %arg1, %c0_i32 : !tt.ptr, i32 + %5 = tt.splat %4 : !tt.ptr -> tensor<1x!tt.ptr> + %6 = arith.truncf %3 : f32 to f16 + %7 = tt.splat %6 : f16 -> tensor<1xf16> + tt.store %5, %7 : tensor<1x!tt.ptr> + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/scf_if_return.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/scf_if_return.mlir new file mode 100644 index 000000000..04d77fb6a --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/scf_if_return.mlir @@ -0,0 +1,42 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @triton_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c52_i32 = arith.constant 52 : i32 + %cst = arith.constant dense<0> : tensor<1x1xi32> + %c50_i32 = arith.constant 50 : i32 + %c2_i32 = arith.constant 2 : i32 + %c25_i32 = arith.constant 25 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.cmpi slt, %0, %c25_i32 : i32 + %2:2 = scf.if %1 -> (i32, i32) { + // CHECK: %[[RET:.*]]:2 = scf.if %[[COND:.*]] -> (i32, i32) { + %3 = arith.muli %0, %c2_i32 : i32 + %4 = arith.addi %3, %c2_i32 : i32 + scf.yield %3, %4 : i32, i32 + } else { + %3 = arith.subi %0, %c25_i32 : i32 + %4 = arith.muli %3, %c2_i32 : i32 + %5 = arith.addi %4, %c50_i32 : i32 + %6 = arith.addi %4, %c52_i32 : i32 + scf.yield %5, %6 : i32, i32 + } + scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { + %3 = arith.addi %2#0, %arg5 : i32 + %4 = arith.cmpi slt, %3, %2#1 : i32 + %5 = tt.splat %4 : i1 -> tensor<1x1xi1> + %6 = tt.addptr %arg0, %3 : !tt.ptr, i32 + %7 = tt.splat %6 : !tt.ptr -> tensor<1x1x!tt.ptr> + %8 = tt.load %7, %5, %cst : tensor<1x1x!tt.ptr> + %9 = tt.addptr %arg1, %3 : !tt.ptr, i32 + %10 = tt.splat %9 : !tt.ptr -> tensor<1x1x!tt.ptr> + %11 = tt.load %10, %5, %cst : tensor<1x1x!tt.ptr> + %12 = arith.addi %8, %11 : tensor<1x1xi32> + %13 = tt.addptr %arg2, %3 : !tt.ptr, i32 + %14 = tt.splat %13 : !tt.ptr -> tensor<1x1x!tt.ptr> + tt.store 
%14, %12, %5 : tensor<1x1x!tt.ptr> + } + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/split.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/split.mlir new file mode 100644 index 000000000..940a6d7c8 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/split.mlir @@ -0,0 +1,46 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +tt.func public @fn_npu_split(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<256> : tensor<16x1xi32> + %cst_0 = arith.constant dense<2> : tensor<1x256x1xi32> + %cst_1 = arith.constant dense<2> : tensor<16x1x1xi32> + %cst_2 = arith.constant dense<256> : tensor<16x1x1xi32> + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %3 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %4 = tt.expand_dims %3 {axis = 2 : i32} : tensor<16x1xi32> -> tensor<16x1x1xi32> + %5 = arith.muli %4, %cst_2 : tensor<16x1x1xi32> + %6 = arith.muli %5, %cst_1 : tensor<16x1x1xi32> + %7 = tt.expand_dims %1 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %8 = tt.expand_dims %7 {axis = 2 : i32} : tensor<1x256xi32> -> tensor<1x256x1xi32> + %9 = arith.muli %8, %cst_0 : tensor<1x256x1xi32> + %10 = tt.broadcast %6 : tensor<16x1x1xi32> -> tensor<16x256x1xi32> + %11 = tt.broadcast %9 : tensor<1x256x1xi32> -> tensor<16x256x1xi32> + %12 = arith.addi %10, %11 : tensor<16x256x1xi32> + %13 = tt.expand_dims %2 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %14 = tt.expand_dims %13 {axis = 1 : i32} : tensor<1x2xi32> -> tensor<1x1x2xi32> + %15 = tt.broadcast %12 : tensor<16x256x1xi32> -> tensor<16x256x2xi32> + %16 = tt.broadcast %14 : tensor<1x1x2xi32> -> tensor<16x256x2xi32> + %17 = arith.addi %15, %16 : tensor<16x256x2xi32> + %18 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x2x!tt.ptr> + %19 = tt.addptr %18, %17 : tensor<16x256x2x!tt.ptr>, tensor<16x256x2xi32> + %20 = tt.load %19 : tensor<16x256x2x!tt.ptr> + %outLHS, %outRHS = tt.split %20 : tensor<16x256x2xf32> -> tensor<16x256xf32> + %21 = arith.muli %3, %cst : tensor<16x1xi32> + %22 = tt.broadcast %21 : tensor<16x1xi32> -> tensor<16x256xi32> + %23 = tt.broadcast %7 : tensor<1x256xi32> -> tensor<16x256xi32> + %24 = arith.addi %22, %23 : tensor<16x256xi32> + %25 = tt.splat %arg0 : !tt.ptr -> tensor<16x256x!tt.ptr> + %26 = tt.addptr %25, %24 : tensor<16x256x!tt.ptr>, tensor<16x256xi32> + tt.store %26, %outLHS : tensor<16x256x!tt.ptr> + %27 = tt.splat %arg4 : !tt.ptr -> tensor<16x256x!tt.ptr> + %28 = tt.addptr %27, %24 : tensor<16x256x!tt.ptr>, tensor<16x256xi32> + tt.store %28, %outRHS : tensor<16x256x!tt.ptr> + tt.return +} + +//CHECK-LABEL: @fn_npu_split +//CHECK-NOT: tt.split +//CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ADDR:.*]] restrict writable : memref<16x256x2xf32> +//CHECK: %[[EXT0:.*]] = tensor.extract_slice %[[VAL0]][0, 0, 0] [16, 256, 1] [1, 1, 2] : tensor<16x256x2xf32> to tensor<16x256xf32> +//CHECK: %[[EXT1:.*]] = tensor.extract_slice %[[VAL0]][0, 0, 1] [16, 256, 1] [1, 1, 2] : tensor<16x256x2xf32> to tensor<16x256xf32> diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/trig_atan.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/trig_atan.mlir new file mode 100644 index 000000000..8836ebf69 --- /dev/null +++ 
b/third_party/ascend/test/Conversion/TritonToLinalg/trig_atan.mlir @@ -0,0 +1,31 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + + +module { + tt.func public @triton_atan(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %3 = tt.splat %arg3 : i32 -> tensor<64xi32> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<64x!tt.ptr> + scf.for %arg4 = %c0_i32 to %c128_i32 step %c64_i32 : i32 { + %6 = arith.addi %1, %arg4 : i32 + %7 = tt.splat %6 : i32 -> tensor<64xi32> + %8 = arith.addi %7, %2 : tensor<64xi32> + %9 = arith.cmpi slt, %8, %3 : tensor<64xi32> + %10 = tt.addptr %4, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + %11 = tt.load %10, %9 : tensor<64x!tt.ptr> + %12 = tt.extern_elementwise %11 {libname = "", libpath = "", pure = true, symbol = "__hmf_atanf"} : (tensor<64xf32>) -> tensor<64xf32> + %13 = tt.addptr %5, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %13, %12, %9 : tensor<64x!tt.ptr> + } + tt.return + } +} + + +//CHECK: %mapped = linalg.map { func.call {callee = @__hmf_atanf} } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/trig_tan.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/trig_tan.mlir new file mode 100644 index 000000000..e96aa750e --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/trig_tan.mlir @@ -0,0 +1,29 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @triton_tan(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32768_i32 = arith.constant 32768 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32768_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %arg3 : i32 -> tensor<1024xi32> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + scf.for %arg4 = %c0_i32 to %c32768_i32 step %c1024_i32 : i32 { + %6 = arith.addi %1, %arg4 : i32 + %7 = tt.splat %6 : i32 -> tensor<1024xi32> + %8 = arith.addi %7, %2 : tensor<1024xi32> + %9 = arith.cmpi slt, %8, %3 : tensor<1024xi32> + %10 = tt.addptr %4, %8 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %11 = tt.load %10, %9 : tensor<1024x!tt.ptr> + %12 = tt.extern_elementwise %11 {libname = "", libpath = "", pure = true, symbol = "__hmf_tanf"} : (tensor<1024xf32>) -> tensor<1024xf32> + %13 = tt.addptr %5, %8 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %13, %12, %9 : tensor<1024x!tt.ptr> + } + tt.return + } +} + +//CHECK: %mapped = linalg.map { func.call {callee = @__hmf_tanf} } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/triton_assert.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/triton_assert.mlir new file mode 100644 index 000000000..0d7c8f09d --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/triton_assert.mlir @@ -0,0 +1,12 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +tt.func public @assert_lol(%arg0: i32) { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32 + %1 = tt.splat %0 : i1 -> tensor<1xi1> + tt.assert %1, "lol" : tensor<1xi1> + 
tt.return +} + +// CHECK: (%arg0: memref, %arg1: memref, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir new file mode 100644 index 000000000..cab426e50 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir @@ -0,0 +1,37 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @rand(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.extern_elementwise %3, %0 {libname = "", libpath = "", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @rand( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<8xi32> +// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<8xi32>) { +// CHECK: ^bb0([[out:.+]]: i32): +// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 +// CHECK: linalg.yield [[VAR_5_]] : i32 +// CHECK: } -> tensor<8xi32> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [8], strides: [1] : memref to memref<8xi32, strided<[1]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<8xi32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<8xi32, strided<[1]>> to memref<8xi32> +// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<8xi32> +// CHECK-DAG: [[VAR_3_:%.+]] = tt.extern_elementwise [[VAR_2_]], [[VAR_1_]] {libname = "", libpath = "", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [8], strides: [1] : memref to memref<8xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_0_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/use_dot_opc.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/use_dot_opc.mlir new file mode 100644 index 000000000..8c3db5dd1 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/use_dot_opc.mlir @@ -0,0 +1,74 @@ +// RUN: 
triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr + ) + { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %c64 = arith.constant 128 : i32 + %1 = tt.splat %c64 : i32 -> tensor<128xi32> + %2 = arith.muli %0, %1 : tensor<128xi32> + %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x64xi32> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %7 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<128x64xi32> + %8 = arith.addi %4, %7 : tensor<128x64xi32> + %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %12 = tt.broadcast %11 : tensor<1x256xi32> -> tensor<64x256xi32> + %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %c256 = arith.constant 256 : i32 + %14 = tt.splat %c256 : i32 -> tensor<64xi32> + %15 = arith.muli %13, %14 : tensor<64xi32> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %17 = tt.broadcast %16 : tensor<64x1xi32> -> tensor<64x256xi32> + %18 = arith.addi %12, %17 : tensor<64x256xi32> + %20 = tt.splat %c256 : i32 -> tensor<128xi32> + %21 = arith.muli %0, %20 : tensor<128xi32> + %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %23 = tt.broadcast %22 : tensor<128x1xi32> -> tensor<128x256xi32> + %24 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %25 = tt.broadcast %24 {axis = 0 : i32} : tensor<1x256xi32> -> tensor<128x256xi32> + %26 = arith.addi %23, %25 : tensor<128x256xi32> + %30 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64x!tt.ptr> + %40 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> + %41 = tt.addptr %40, %18 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<64x256x!tt.ptr> + %50 = tt.splat %arg2 : !tt.ptr -> tensor<128x256x!tt.ptr> + %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %cf0 = arith.constant 0.0 : bf16 + %71 = tt.splat %cf0 : bf16 -> tensor<128x256xbf16> + %60 = tt.dot %32, %42, %71 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> + tt.store %51, %60 : tensor<128x256x!tt.ptr> + tt.store %51, %71 : tensor<128x256x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: [[PARAM_0_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_1_:%.+]]: memref {tt.tensor_kind = 0 : i32}, [[PARAM_2_:%.+]]: memref {tt.tensor_kind = 1 : i32}, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) +// CHECK-SAME: attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "mix"} { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = 
linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: [128, 1] : memref to memref<128x64xbf16, strided<[128, 1]>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[128, 1]>> to memref<128x64xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [64, 256], strides: [256, 1] : memref to memref<64x256xbf16, strided<[256, 1]>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16> +// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<64x256xbf16, strided<[256, 1]>> to memref<64x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16> +// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: [256, 1] : memref to memref<128x256xbf16, strided<[256, 1]>> +// CHECK: [[VAR_4_:%.+]] = linalg.matmul {input_precison = "ieee"} ins([[VAR_2_]], [[VAR_3_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_1_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: bufferization.materialize_in_destination [[VAR_4_]] in writable [[VAR_reinterpret_cast_2_]] +// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_2_]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/use_end_chain.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/use_end_chain.mlir new file mode 100644 index 000000000..2b4cbeb6a --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/use_end_chain.mlir @@ -0,0 +1,75 @@ +// RUN: triton-adapter-opt --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr + ) + { + %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> + // offset = [512] size = 256, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + // offset = [512,0], size = [256,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> + // offset = [512,0], size = [256,128], stride = [1,0] + %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> + // offset = 1024, size = 128, stride = 1 + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + // offset = [0,1024], size = [1,128], stride = [0,1] + %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> + // offset = [0,1024], size = [256,128], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> + %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> + // offset = [0,6144], size = [256,128], stride = [0,6] + %14 = arith.addi %2, %scale7 : tensor<256x128xi32> + // offset = [512,6144], size = [256,128], stride = [1,6] + // mixed use + %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> + %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, 
tensor<256x128xi32> + %19 = tt.load %18 : tensor<256x128x!tt.ptr> + tt.store %18, %19 : tensor<256x128x!tt.ptr> + %20 = arith.sitofp %14 : tensor<256x128xi32> to tensor<256x128xbf16> + tt.store %18, %20 : tensor<256x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref {tt.tensor_kind = 2 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 512 : i32 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : i32 +// CHECK: %[[VAL_30:.*]] = tensor.empty() : tensor<256x128xi32> +// CHECK: %[[VAL_31:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_30]] : tensor<256x128xi32>) -> tensor<256x128xi32> +// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> +// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { +// CHECK: ^bb0(%[[VAL_10:.*]]: i32): +// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index +// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_6]] : i32) outs(%[[VAL_8]] : tensor<256xi32>) -> tensor<256xi32> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_9]], %[[VAL_13]] : tensor<256xi32> +// CHECK: %[[VAL_15:.*]] = linalg.broadcast ins(%[[VAL_14]] : tensor<256xi32>) outs(%[[VAL_30]] : tensor<256x128xi32>) dimensions = [1] +// CHECK: %[[VAL_16:.*]] = tensor.empty() : tensor<128xi32> + +// CHECK: %[[VAL_20:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_16]] : tensor<128xi32>) { +// CHECK: ^bb0(%[[VAL_21:.*]]: i32): +// CHECK: %[[VAL_22:.*]] = linalg.index 0 : index +// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i32 +// CHECK: linalg.yield %[[VAL_23]] : i32 +// CHECK: } -> tensor<128xi32> +// CHECK: %[[VAL_24:.*]] = linalg.fill ins(%[[VAL_5]] : i32) outs(%[[VAL_16]] : tensor<128xi32>) -> tensor<128xi32> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_20]], %[[VAL_24]] : tensor<128xi32> +// CHECK: %[[VAL_26:.*]] = linalg.broadcast ins(%[[VAL_25]] : tensor<128xi32>) outs(%[[VAL_30]] : tensor<256x128xi32>) dimensions = [0] +// CHECK: %[[VAL_32:.*]] = arith.muli %[[VAL_26]], %[[VAL_31]] : tensor<256x128xi32> +// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_15]], %[[VAL_32]] : tensor<256x128xi32> +// CHECK: %[[VAL_45:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, 6] : memref to memref<256x128xbf16, strided<[1, 6], offset: 6656>> +// CHECK: %[[VAL_46:.*]] = memref.alloc() : memref<256x128xbf16> +// CHECK: memref.copy %[[VAL_45]], %[[VAL_46]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16> +// CHECK: %[[VAL_47:.*]] = bufferization.to_tensor %[[VAL_46]] restrict writable : memref<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_47]] in writable %[[VAL_45]] : (tensor<256x128xbf16>, memref<256x128xbf16, strided<[1, 6], offset: 6656>>) -> () +// CHECK: %[[VAL_48:.*]] = arith.sitofp %[[VAL_38]] : tensor<256x128xi32> to tensor<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_48]] in writable %[[VAL_45]] : (tensor<256x128xbf16>, 
memref<256x128xbf16, strided<[1, 6], offset: 6656>>) -> () +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/use_mid_chain.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/use_mid_chain.mlir new file mode 100644 index 000000000..d4285f50f --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/use_mid_chain.mlir @@ -0,0 +1,62 @@ +// RUN: triton-adapter-opt --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' %s | FileCheck %s +module { + tt.func @kernel( + %arg0 : !tt.ptr, + %arg1 : !tt.ptr, + %arg2 : !tt.ptr + ) + { + %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> + // offset = [512] size = 256, stride = 1 + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + // offset = [512,0], size = [256,1], stride = [1,0] + %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> + // offset = [512,0], size = [256,128], stride = [1,0] + // mixed use + %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> + // offset = 1024, size = 128, stride = 1 + %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + // offset = [0,1024], size = [1,128], stride = [0,1] + %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> + // offset = [0,1024], size = [256,128], stride = [0,1] + %c6 = arith.constant 6 : i32 + %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> + %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> + // offset = [0,6144], size = [256,128], stride = [0,6] + %14 = arith.addi %2, %scale7 : tensor<256x128xi32> + // offset = [512,6144], size = [256,128], stride = [1,6] + %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> + %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + %19 = tt.load %18 : tensor<256x128x!tt.ptr> + tt.store %18, %19 : tensor<256x128x!tt.ptr> + %20 = tt.splat %arg2 : !tt.ptr -> tensor<256x128x!tt.ptr> + %21 = tt.addptr %20, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> + tt.store %21, %2 : tensor<256x128x!tt.ptr> + tt.return + } +} +// CHECK-LABEL: func.func @kernel( +// CHECK-SAME: %[[ARG_0:.*]]: memref, %[[ARG_1:.*]]: memref, +// CHECK-SAME: %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref {tt.tensor_kind = 2 : i32}, %[[VAL_2:.*]]: memref {tt.tensor_kind = 1 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "local", mix_mode = "aiv"} { +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 512 : i32 +// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> +// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { +// CHECK: ^bb0(%[[VAL_10:.*]]: i32): +// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index +// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 +// CHECK: linalg.yield %[[VAL_12]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_8]] : tensor<256xi32>) -> tensor<256xi32> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_9]], %[[VAL_13]] : tensor<256xi32> +// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> +// CHECK: %[[VAL_16:.*]] = linalg.broadcast ins(%[[VAL_14]] : tensor<256xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) dimensions = [1] + 
+// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, 6] : memref to memref<256x128xbf16, strided<[1, 6], offset: 6656>> +// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<256x128xbf16> +// CHECK: memref.copy %[[VAL_19]], %[[VAL_20]] : memref<256x128xbf16, strided<[1, 6], offset: 6656>> to memref<256x128xbf16> +// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_20]] restrict writable : memref<256x128xbf16> +// CHECK: bufferization.materialize_in_destination %[[VAL_21]] in writable %[[VAL_19]] +// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, 6] : memref to memref<256x128xi32, strided<[1, 6], offset: 6656>> +// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] +// CHECK: return +// CHECK: } diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/view.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/view.mlir new file mode 100644 index 000000000..1679b1416 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/view.mlir @@ -0,0 +1,25 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +module { + tt.func public @auto_gen_kernel_01(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant dense<512> : tensor<1xi64> +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[4, 4, 32]> : tensor<3xi64> + %c512_i32 = arith.constant 512 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c512_i32 : i32 + %2 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> + %3 = tt.splat %1 : i32 -> tensor<512xi32> + %4 = arith.addi %3, %2 : tensor<512xi32> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<512x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + %7 = tt.load %6 evictionPolicy = evict_last : tensor<512x!tt.ptr> +// CHECK: %[[VAL_2:.*]] = tensor.reshape %[[VAL_0:.*]](%[[VAL_1:.*]]) : (tensor<512xf64>, tensor<3xi64>) -> tensor<4x4x32xf64> + %8 = tt.reshape %7 : tensor<512xf64> -> tensor<4x4x32xf64> + %9 = arith.addf %8, %8 : tensor<4x4x32xf64> +// CHECK: %[[VAL_7:.*]] = tensor.reshape %[[VAL_3:.*]](%[[VAL_4]]) : (tensor<4x4x32xf64>, tensor<1xi64>) -> tensor<512xf64> + %10 = tt.reshape %9 : tensor<4x4x32xf64> -> tensor<512xf64> + %11 = tt.splat %arg0 : !tt.ptr -> tensor<512x!tt.ptr> + %12 = tt.addptr %11, %4 : tensor<512x!tt.ptr>, tensor<512xi32> + tt.store %12, %10 evictionPolicy = evict_last : tensor<512x!tt.ptr> + tt.return + } +} diff --git a/third_party/ascend/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir b/third_party/ascend/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir new file mode 100644 index 000000000..091f470e1 --- /dev/null +++ b/third_party/ascend/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir @@ -0,0 +1,57 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s +// XFAIL: * +// We currently do not support this kind of modulo pattern: +// (a + arrange(0, K)) % M +module { + tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %cst_0 = arith.constant dense<2> : tensor<4x1xi32> + %cst_1 = arith.constant dense<6> : tensor<4xi32> + 
%cst_2 = arith.constant dense<2> : tensor<4xi32> + %c4_i32 = arith.constant 4 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = arith.addi %0, %cst_2 : tensor<4xi32> + %2 = tt.splat %arg3 : i32 -> tensor<4xi32> + %3 = arith.remsi %0, %2 : tensor<4xi32> + %4 = arith.addi %3, %cst_1 : tensor<4xi32> + %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> + %7 = arith.muli %5, %6 : tensor<4x1xi32> + %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> + %10 = arith.muli %8, %9 : tensor<1x4xi32> + %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> + %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> + %13 = arith.addi %11, %12 : tensor<4x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> + %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> + %18 = arith.muli %17, %16 : tensor<4x1xi32> + %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> + %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> + %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> + %23 = arith.muli %22, %21 : tensor<1x4xi32> + %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> + %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> + %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> + %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> + %29 = arith.muli %arg4, %c4_i32 : i32 + %30 = tt.splat %29 : i32 -> tensor<4x4xi32> + %31 = arith.muli %arg5, %c4_i32 : i32 + %32 = tt.splat %31 : i32 -> tensor<4x4xi32> + %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { + %34 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> + tt.store %arg10, %34 : tensor<4x4x!tt.ptr> + %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> + scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> + } + tt.return + } +} diff --git a/third_party/ascend/test/lit.cfg.py b/third_party/ascend/test/lit.cfg.py new file mode 100644 index 000000000..f2c263fdd --- /dev/null +++ b/third_party/ascend/test/lit.cfg.py @@ -0,0 +1,67 @@ +# -*- Python -*- + +import os +import platform +import re +import subprocess +import tempfile +import triton + +import lit.formats +import lit.util +from lit.llvm import llvm_config +from lit.llvm.subst import FindTool, ToolSubst + +# Configuration file for the 'lit' test runner + +# name: The name of this test suite +config.name = 'TRITON-ADAPTER' + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.mlir'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. 
+config.test_exec_root = os.path.join(config.triton_obj_root, 'test') +config.substitutions.append(('%PATH%', config.environment['PATH'])) + +config.quiet = False +config.show_suites = True +config.show_tests = True +config.show_uses = True + +llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) + +# excludes: A list of directories to exclude from the testsuite. The 'Inputs' +# subdirectories contain auxiliary inputs for various tests in their parent +# directories. +config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.triton_obj_root, 'test') +triton_tools_root = os.path.dirname(triton.__file__) +config.triton_tools_dir = os.path.join(triton_tools_root, 'backends', 'ascend') +config.filecheck_dir = config.llvm_tools_dir + +tool_dirs = [config.triton_tools_dir, config.llvm_tools_dir, config.filecheck_dir] + +# Tweak the PATH to include the tools dir. +for d in tool_dirs: + llvm_config.with_environment('PATH', d, append_path=True) +tools = [ + 'triton-adapter-opt', + ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) + +llvm_config.with_environment('PYTHONPATH', [ + os.path.join(config.mlir_binary_dir, 'python_packages', 'triton'), +], append_path=True) diff --git a/third_party/ascend/test/lit.site.cfg.py.in b/third_party/ascend/test/lit.site.cfg.py.in new file mode 100644 index 000000000..c614871fb --- /dev/null +++ b/third_party/ascend/test/lit.site.cfg.py.in @@ -0,0 +1,24 @@ +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.triton_obj_root = "@TRITON_BINARY_DIR@" +config.triton_adapter_obj_root = "@TRITON_ADAPTER_BINARY_DIR@" +config.llvm_src_root = "@LLVM_SOURCE_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_lib_dir = "@LLVM_LIBS_DIR@" +config.llvm_shlib_dir = "@CMAKE_LIBRARY_OUTPUT_DIRECTORY@" +config.llvm_shlib_ext = "@CMAKE_SHARED_LIBRARY_SUFFIX@" +config.llvm_exe_ext = "@EXEEXT@" +config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" +config.mlir_binary_dir = "@MLIR_BINARY_DIR@" +config.python_executable = "@Python3_EXECUTABLE@" +config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ + + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work +lit_config.load_config(config, "@CMAKE_CURRENT_SOURCE_DIR@/lit.cfg.py") diff --git a/third_party/ascend/test/sglang/conftest.py b/third_party/ascend/test/sglang/conftest.py new file mode 100644 index 000000000..1971d4285 --- /dev/null +++ b/third_party/ascend/test/sglang/conftest.py @@ -0,0 +1,44 @@ +import os +import subprocess +from pathlib import Path +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--ptfile_path", + type=str, + default=None, + help="the test-file path (pt file)" + ) + + +@pytest.fixture(scope="function") +def ptfile_path(request): + filepath = request.config.getoption("--ptfile_path") + if not filepath: + # fatch default ptfile + test_case_name = request.node.name + remote_default_ptfile_url = f"https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com/test/SGLang/{test_case_name}/default.pt" + local_path = Path(f"default_{test_case_name}.pt") + + # download ptfile + try: + subprocess.run( + ["curl", "-f", "-s", "-S", "-L", "-o", 
str(local_path), remote_default_ptfile_url], + check=True, + capture_output=True + ) + print(f"default ptfile is saved to: {local_path}") + except subprocess.CalledProcessError as e: + pytest.fail("download default ptfile error!") + + if not local_path.exists(): + pytest.fail(f"the {local_path} does not exist!") + + return str(local_path) + + if not os.path.exists(filepath): + pytest.fail(f"the {filepath} does not exist!") + + return filepath diff --git a/third_party/ascend/test/sglang/test_common.py b/third_party/ascend/test/sglang/test_common.py new file mode 100644 index 000000000..798d801e4 --- /dev/null +++ b/third_party/ascend/test/sglang/test_common.py @@ -0,0 +1,158 @@ +from typing import Optional +import torch +import torch_npu +import pytest + +eval_standard = { + torch.float32: { + "rtol": 1e-6, + "small_value": 1e-6, + "small_value_atol": 1e-9, + "etol": 1e-4, + }, + torch.float16: { + "rtol": 1e-3, + "small_value": 1e-3, + "small_value_atol": 1e-5, + "etol": 1e-3, + }, + torch.bfloat16: { + "rtol": 4e-3, + "small_value": 1e-3, + "small_value_atol": 1e-5, + "etol": 1e-3, + }, +} + + +def validate_cmp(dtype, y_cal, y_ref, overflow_mode: Optional[str] = None, device_type: Optional[str] = None): + if device_type is not None: + target_device = torch.device(device_type) + y_cal = y_cal.to(target_device) + y_ref = y_ref.to(target_device) + else: + y_cal=y_cal.npu() + y_ref=y_ref.npu() + if overflow_mode == "saturate": + if dtype in ['float32', 'float16']: + min_value = -torch.finfo(dtype).min + max_value = torch.finfo(dtype).max + elif dtype in ['int32', 'int16', 'int8']: + min_value = torch.iinfo(dtype).min + max_value = torch.iinfo(dtype).max + elif dtype == 'bool': + min_value = 0 + max_value = 1 + else: + raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) + y_ref = torch.clamp(y_ref, min=min_value, max=max_value) + if dtype == 'float16': + torch.testing.assert_close(y_ref, y_cal, rtol=5e-03, atol=5e-03, equal_nan=True) + elif dtype == 'bfloat16': + torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=5e-03, atol=5e-03, equal_nan=True) + elif dtype == 'float32': + torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) + elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': + assert torch.equal(y_cal, y_ref) + elif dtype == 'bool': + assert torch.equal(y_cal, y_ref) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + +def convert_tensor_with_device_type(indata: dict, device_type: str): + target_device = torch.device(device_type) + outdata = {} + + for key, value in indata.items(): + if isinstance(value, torch.Tensor): + if value.device.type != target_device.type: + outdata[key] = value.to(target_device) + else: + outdata[key] = value + else: + outdata[key] = value + + return outdata + + +def compare_data_precision(dict_ref: dict, dict_cal: dict, device_type: str): + keys_ref, keys_cal = set(dict_ref.keys()), set(dict_cal.keys()) + if not keys_ref.issubset(keys_cal): + raise ValueError("The keys of dict_ref is not subset of dict_cal") + + for key in dict_ref.keys(): + val_a, val_b = dict_ref[key], dict_cal[key] + if type(val_a) != type(val_b): + raise ValueError("The data type of two dicts are different") + + if isinstance(val_a, torch.Tensor): + validate_cmp(dtype=str(val_a.dtype).split('.')[-1], y_ref=val_a, y_cal=val_b, device_type=device_type) + else: + raise ValueError("Non-tensor type is not currently supported") + + +# 
When verifying floating-point precision, differences in the underlying implementations make it hard to compare GPU and NPU results directly.
+# The NPU result promoted to fp32 serves as the 'gold' reference; when the deviation is small, the NPU result (act) and the GPU result (std) are considered to match in precision.
+# This applies in particular to float16 and bfloat16.
+def verify_precision_by_gold_standard(gold: torch.Tensor, act: torch.Tensor, std: torch.Tensor):
+    assert act.dtype == std.dtype, "standard tensor's dtype must equal the actual tensor's dtype!"
+    if act.dtype == torch.float16 or act.dtype == torch.float32 or act.dtype == torch.bfloat16:
+        assert gold.dtype == torch.float32, "golden should be f32"
+    assert not (torch.isnan(act).any() or torch.isinf(act).any()), "actual tensor must not contain 'inf' or 'nan'"
+
+    gold = gold.cpu()
+    act = act.cpu()
+    std = std.cpu()
+
+    eps = eval_standard[act.dtype]['small_value']
+    atol = eval_standard[act.dtype]['small_value_atol']
+
+    mask = torch.abs(gold) <= eps
+    small_count = mask.sum().item()
+
+    def calculate_relative_errors_except_small(tensor):
+        re = torch.abs(gold - tensor) / torch.abs(gold)
+        return torch.where(mask, 0, re)
+
+    act_re = calculate_relative_errors_except_small(act)
+    std_re = calculate_relative_errors_except_small(std)
+    act_ae = torch.abs(gold - act)
+    std_ae = torch.abs(gold - std)
+
+    # The small-value domain is defined as positions where |golden| is below the threshold eps
+    act_small_error_count = (mask & (act_ae > atol)).sum().item()
+    std_small_error_count = (mask & (std_ae > atol)).sum().item()
+    act_total = act.numel()
+    std_total = std.numel()
+
+    act_small_error_ratio = act_small_error_count / act_total
+    std_small_error_ratio = std_small_error_count / std_total
+
+    def calculate_rmse(tensor):
+        dlt2 = (tensor - gold) ** 2
+        dlt2_except_small_mean = torch.where(mask, 0, dlt2).sum() / small_count
+        return torch.sqrt(dlt2_except_small_mean)
+
+    act_rmse = calculate_rmse(act)
+    std_rmse = calculate_rmse(std)
+
+    print(f"act_re.max = {act_re.max()}, std_re.max = {std_re.max()}, limit ratio = 10")
+    print(f"act_re.sum = {act_re.sum()}, std_re.sum = {std_re.sum()}, limit_ratio = 2")
+    print(
+        f"act_small_error_ratio = {act_small_error_ratio}, std_small_error_ratio = {std_small_error_ratio}, limit_ratio = 2")
+    print(f"act_rmse = {act_rmse}, std_rmse = {std_rmse}, limit_ratio = 2")
+
+    # Condition 1: the max relative error of actual vs. golden must not exceed 10x that of standard vs. golden
+    assert act_re.max() <= 10 * std_re.max(), "actual re max > 10x standard re max"
+
+    # Condition 2: the summed relative error of actual vs. golden must not exceed 2x that of standard vs. golden
+    assert act_re.sum() <= 2 * std_re.sum(), "actual re sum > 2x standard re sum"
+
+    # Condition 3: the small-value-domain error ratio of actual must not exceed 2x that of standard
+    assert act_small_error_ratio <= 2 * std_small_error_ratio, "act_small_error_ratio > 2x std_small_error_ratio"
+
+    # Condition 4: the RMSE of actual must not exceed 2x the RMSE of standard
+    assert act_rmse <= 2 * std_rmse, "act_rmse > 2x std_rmse"
+
+    return True
\ No newline at end of file
diff --git a/third_party/ascend/test/sglang/v0.4.8/test__fwd_kernel_ep_scatter_1.py b/third_party/ascend/test/sglang/v0.4.8/test__fwd_kernel_ep_scatter_1.py
new file mode 100644
index 000000000..e502c6499
--- /dev/null
+++ b/third_party/ascend/test/sglang/v0.4.8/test__fwd_kernel_ep_scatter_1.py
@@ -0,0 +1,60 @@
+import sys
+import pytest
+import torch
+
+import triton
+import triton.language as tl
+
+
+sys.path.append("..")
+import test_common
+
+
+# source: /sglang/srt/layers/moe/ep_moe/kernels.py
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert
+ offset_cumsum, + mask=offset_cumsum < num_experts, + other=0, + ) + cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert + tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) + + cur_expert_start = tl.load(expert_start_loc + cur_expert) + cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) + + m_indices_start_ptr = m_indices + cur_expert_start + off_expert = tl.arange(0, BLOCK_E) + + for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + tl.store( + m_indices_start_ptr + start_m + off_expert, + cur_expert, + ) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + _fwd_kernel_ep_scatter_1[data['grid']](**input_data) + + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__fwd_kernel_flash_decode_stage2.py b/third_party/ascend/test/sglang/v0.4.8/test__fwd_kernel_flash_decode_stage2.py new file mode 100644 index 000000000..930e7ab81 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__fwd_kernel_flash_decode_stage2.py @@ -0,0 +1,78 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\triton_ops\double_sparsity_attention.py +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + block_n_size = ( + tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) + // BLOCK_SEQ + ) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_seq_n in range(0, block_n_size, 1): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) + return + + +def test__fwd_kernel_flash_decode_stage2(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + 
_fwd_kernel_flash_decode_stage2[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__gate_up_lora_b_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test__gate_up_lora_b_kernel.py new file mode 100644 index 000000000..085051784 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__gate_up_lora_b_kernel.py @@ -0,0 +1,102 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/lora/triton_ops/gate_up_lora_b.py +@triton.jit +def _gate_up_lora_b_kernel( + x, + weights, + output, + K, + output_dim, + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + fuse_scaling_add, + scalings, +): + batch_id = tl.program_id(axis=2) + gate_up_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + n_start = gate_up_id * output_dim + rank = tl.load(lora_ranks + w_index) + scaling = tl.load(scalings + w_index) + + K = tl.minimum(K, rank) + + num_pid_n = tl.cdiv(output_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + + for_num = tl.cdiv(K, BLOCK_K) + k = 0 + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) + and (n_offset[None, :] < output_dim), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def test__gate_up_lora_b_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _gate_up_lora_b_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
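+    # Note: the kernel body above performs a single BLOCK_K-sized load (k is fixed
+    # at 0 and for_num is not used as a loop bound), so the captured test cases
+    # presumably use adapter ranks that fit into one K block. compare_data_precision
+    # then checks every tensor in gpu_output against the NPU-updated input buffers.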
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test__per_token_group_quant_int8.py b/third_party/ascend/test/sglang/v0.4.8/test__per_token_group_quant_int8.py new file mode 100644 index 000000000..212bf9906 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__per_token_group_quant_int8.py @@ -0,0 +1,72 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\quantization\int8_kernel.py +@triton.jit +def _per_token_group_quant_int8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Columns of input + N, + # Avoid to divide zero + eps, + # Information for int8 + int8_min, + int8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group quantization on a + tensor. + + This function converts the tensor values into int8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / int8_max + y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def test__per_token_group_quant_int8(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + _per_token_group_quant_int8[data["grid"]](**input_data) + + ref_s = data['gpu_output']['y_s_ptr'].cpu() + cal_s = input_data['y_s_ptr'].cpu() + torch.testing.assert_close(ref_s, cal_s, rtol=5e-03, atol=5e-03, equal_nan=True) + + #ensure the difference of y_q_ptr no more than 1 + ref_q = data['gpu_output']['y_q_ptr'].cpu() + cal_q = input_data['y_q_ptr'].cpu() + diff = torch.abs(ref_q - cal_q) + diff_over_1 = diff > 1 + assert diff_over_1.sum().item() == 0 \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__per_token_quant_int8.py b/third_party/ascend/test/sglang/v0.4.8/test__per_token_quant_int8.py new file mode 100644 index 000000000..ad5e3419f --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__per_token_quant_int8.py @@ -0,0 +1,56 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\quantization\int8_kernel.py +@triton.jit +def _per_token_quant_int8( + x_ptr, + xq_ptr, + scale_ptr, + x_sum_ptr, + stride_x, + stride_xq, + N, + CAL_SUM: tl.constexpr, + BLOCK: tl.constexpr, +): + row_id = tl.program_id(0) + + cols = tl.arange(0, BLOCK) + mask = cols < N + + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) + scale_x = absmax / 127 + x_q = x * (127 / absmax) + x_q = tl.extra.ascend.libdevice.round(x_q).to(tl.int8) + if CAL_SUM: + x_sum = tl.sum(x, axis=0) + tl.store(x_sum_ptr + row_id, 
x_sum.to(x_sum_ptr.dtype.element_ty)) + + tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) + tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty)) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _per_token_quant_int8[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__sgemm_lora_a_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test__sgemm_lora_a_kernel.py new file mode 100644 index 000000000..af98ed8dd --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__sgemm_lora_a_kernel.py @@ -0,0 +1,106 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +import os + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/lora/triton_ops/sgemm_lora_a.py +@triton.jit +def _sgemm_lora_a_kernel( + x, + weights, + output, + N, + K, + stack_num, + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + out_x_tile, + out_w_tile, + out_tmp_dot, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + batch_id = tl.program_id(axis=1) + w_index = tl.load(weight_indices + batch_id) + rank = tl.load(lora_ranks + w_index) + if rank == 0: + return + + pid = tl.program_id(axis=0) + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + N = tl.minimum(N, rank * stack_num) + + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N), + other=0.0, + ) + + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def test__sgemm_lora_b_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + 
_sgemm_lora_a_kernel[data["grid"]](**input_data) + + gpu_output = data['gpu_output']['output'] + npu_output = input_data['output'].cpu() + # compare the results of GPU and NPU. + try: + torch.testing.assert_close(gpu_output, npu_output, rtol=0.2, atol=0.02, equal_nan=True) + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__sgemm_lora_b_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test__sgemm_lora_b_kernel.py new file mode 100644 index 000000000..2f7e72a56 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__sgemm_lora_b_kernel.py @@ -0,0 +1,138 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/lora/triton_ops/sgemm_lora_b.py +@triton.jit +def _sgemm_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # output_dim + K, # r + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling + scalings, +): + """ + Computes a segmented batched matrix multiplication for the LoRA B matrix + and adds the result to the output in-place. + + When a sequence's rank is 0, the kernel is essentially a no-op, following + the convention in pytorch where the product of two matrices of shape (m, 0) + and (0, n) is an all-zero matrix of shape (m, n). + + Args: + x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication, + of shape `(s, K)`, where `s` is the total number of tokens. + weights (torch.Tensor): The LoRA 'B' weights for all available adapters, + with shape `(num_lora, N, K)`. + output (torch.Tensor): The output tensor of shape `(s, N)`. This can be + the base model's output for a fused add operation. + """ + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + w_index = tl.load(weight_indices + batch_id) + rank = tl.load(lora_ranks + w_index) + + # If rank is 0, this kernel is a no-op. 
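+    # rank comes from lora_ranks[weight_indices[batch_id]], and K is later clamped
+    # to this rank, so a zero entry skips the whole segment here.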
+ if rank == 0: + return + + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + scaling = tl.load(scalings + w_index) + # Adjust K (rank) according to the specific LoRA adapter + K = tl.minimum(K, rank) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iterate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = s_offset[:, None] < seg_len + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def test__sgemm_lora_b_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _sgemm_lora_b_kernel[data["grid"]](**input_data) + + gpu_output = data['gpu_output']['output'] + npu_output = input_data['output'].cpu() + # compare the results of GPU and NPU. 
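+    # The B kernel scales the tile and accumulates into `output` in place, so the
+    # dense output tensor is compared directly with a fairly loose tolerance
+    # instead of going through test_common.compare_data_precision.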
+ try: + torch.testing.assert_close(gpu_output, npu_output, rtol=1.02, atol=0.02, equal_nan=True) + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__sparse_fwd_kernel_flash_decode_stage1.py b/third_party/ascend/test/sglang/v0.4.8/test__sparse_fwd_kernel_flash_decode_stage1.py new file mode 100644 index 000000000..33379e73d --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__sparse_fwd_kernel_flash_decode_stage1.py @@ -0,0 +1,105 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common +REDUCE_TRITON_TYPE = tl.float32 + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +# source: python\sglang\srt\layers\attention\triton_ops\double_sparsity_attention.py +@triton.jit +def _sparse_fwd_kernel_flash_decode_stage1( # Double Sparsity's approximate attention + Q_Label, + K_Label_Buffer, + sm_scale, + Req_to_tokens, # shape: [B, S] + B_Seqlen, + Att_Out, # shape: [H, B, S] easier for topk + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + att_stride_h, + att_stride_b, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + logit_cap: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_n = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + cur_batch_start_index = 0 + cur_batch_end_index = cur_batch_seq_len + + min_val = -float("inf") + att_value = tl.full([BLOCK_N], min_val, dtype=tl.float32) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_index = start_n * BLOCK_N + block_mask = tl.where(block_index < cur_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q_Label + off_q + start_mark).to(REDUCE_TRITON_TYPE) + offs_n_new = cur_batch_start_index + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + offs_buf_k = ( + k_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Label_Buffer + offs_buf_k, + mask=offs_n_new[:, None] < cur_batch_end_index, + other=0.0, + ).to(REDUCE_TRITON_TYPE) + + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + + if logit_cap > 0: + att_value = logit_cap * tanh(att_value / logit_cap) + + att_value = tl.where(offs_n < cur_batch_end_index, att_value, min_val) + off_o = cur_head * att_stride_h + (cur_batch * att_stride_b + offs_n) + tl.store(Att_Out + off_o, att_value) + + +def test__sparse_fwd_kernel_flash_decode_stage1(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _sparse_fwd_kernel_flash_decode_stage1[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
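+    # Stage 1 only writes approximate attention scores into Att_Out (laid out as
+    # [H, B, S] for the later top-k selection), so an elementwise comparison of
+    # the captured buffers is sufficient here.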
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__sparse_fwd_kernel_flash_decode_stage2.py b/third_party/ascend/test/sglang/v0.4.8/test__sparse_fwd_kernel_flash_decode_stage2.py new file mode 100644 index 000000000..d4a51bfbb --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__sparse_fwd_kernel_flash_decode_stage2.py @@ -0,0 +1,137 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\triton_ops\double_sparsity_attention.py +@triton.jit +def _sparse_fwd_kernel_flash_decode_stage2( + Q, + K, + V, + sm_scale, + Req_to_tokens, # shape: [B, S] + Topk_token_indices, # shape: [H, B, k] + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + Heavy_token_num, # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future + stride_req_to_tokens_b, + stride_topk_token_indices_h, + stride_topk_token_indices_b, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_o_eb, + stride_mid_o_eh, + gqa_group_size, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(Heavy_token_num, cur_batch_start_index + BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(cur_batch_start_index, cur_batch_end_index, BLOCK_N): + offs_n_new = start_n + offs_n + topk_token_indices = tl.load( + Topk_token_indices + + stride_topk_token_indices_h * cur_head + + stride_topk_token_indices_b * cur_batch + + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch + topk_token_indices, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + k = tl.load( + K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0 + ) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) + v = tl.load( + V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0 + ) + + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + need_store = 1 + for _ in range(0, need_store, 1): + off_mid_o = ( + cur_batch * 
stride_mid_ob + + cur_head * stride_mid_oh + + seq_start_block * stride_mid_os + + offs_d + ) + off_mid_o_logexpsum = ( + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + ) + tl.store(Mid_O + off_mid_o, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + return + + +def test__sparse_fwd_kernel_flash_decode_stage2(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _sparse_fwd_kernel_flash_decode_stage2[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test__w8a8_block_int8_matmul.py b/third_party/ascend/test/sglang/v0.4.8/test__w8a8_block_int8_matmul.py new file mode 100644 index 000000000..385dbcbc1 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test__w8a8_block_int8_matmul.py @@ -0,0 +1,112 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\quantization\int8_kernel.py +@triton.jit +def _w8a8_block_int8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`. 
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def test__w8a8_block_int8_matmul(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _w8a8_block_int8_matmul[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
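+    # Judging from the strides above, As is indexed per (row, K-group) and Bs per
+    # (K-group, N-group), so each BLOCK_SIZE_K step reloads one scale per row and
+    # per column tile before accumulating into float32.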
+ # the precision of tl.dot in GPU is lower than that in NPU + try: + torch.testing.assert_close(data["gpu_output"]['C'], input_data['C'].cpu(), rtol=1e-02, atol=1e-02, equal_nan=True) + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_align_evict_mask_to_page_size.py b/third_party/ascend/test/sglang/v0.4.8/test_align_evict_mask_to_page_size.py new file mode 100644 index 000000000..7c68ddbe7 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_align_evict_mask_to_page_size.py @@ -0,0 +1,55 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/speculative/eagle_utils.py +@triton.jit +def align_evict_mask_to_page_size( + seq_lens, + evict_mask, + page_size: tl.constexpr, + num_draft_tokens: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + t_range = tl.arange(0, BLOCK_SIZE) + + bid = tl.program_id(axis=0) + seq_len = tl.load(seq_lens + bid) + io_mask = t_range < num_draft_tokens + mask_row = tl.load( + evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0 + ) + + num_trues = tl.sum(mask_row) + num_false = num_draft_tokens - num_trues + + start = (seq_len + num_false - 1) // page_size * page_size - seq_len + for i in range(max(start, 0), min(start + page_size, num_draft_tokens)): + tl.store(evict_mask + bid * num_draft_tokens + i, False) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + align_evict_mask_to_page_size[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
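+    # The kernel stores False over at most one page_size window per row of
+    # evict_mask, which appears to keep each request's evictions aligned to
+    # page_size boundaries; the in-place buffers are then compared below.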
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_alloc_decode_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_alloc_decode_kernel.py new file mode 100644 index 000000000..17582127d --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_alloc_decode_kernel.py @@ -0,0 +1,74 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/mem_cache/allocator.py +@triton.jit +def alloc_decode_kernel( + seq_lens_ptr, + last_loc_ptr, + free_page_ptr, + out_indices, + ret_values, + bs_upper: tl.constexpr, + page_size: tl.constexpr, +): + pid = tl.program_id(0) + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid) + pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens) + + seq_len = tl.load(seq_lens_ptr + pid) + pre_len = seq_len - 1 + + num_pages_after = (seq_lens + page_size - 1) // page_size + num_pages_before = (pre_lens + page_size - 1) // page_size + num_new_pages = num_pages_after - num_pages_before + + num_page_start_loc_self = (seq_len + page_size - 1) // page_size - ( + pre_len + page_size - 1 + ) // page_size + sum_num_new_pages = tl.sum(num_new_pages) + new_page_start_loc = sum_num_new_pages - num_page_start_loc_self + + # Return value + if pid == tl.num_programs(0) - 1: + tl.store(ret_values, sum_num_new_pages) + + if num_page_start_loc_self == 0: + last_loc = tl.load(last_loc_ptr + pid) + tl.store(out_indices + pid, last_loc + 1) + else: + page = tl.load(free_page_ptr + new_page_start_loc) + tl.store(out_indices + pid, page * page_size) + + +def test_alloc_decode_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + alloc_decode_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
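+    # Illustration with hypothetical values and page_size=4: a request growing
+    # from seq_len 8 to 9 crosses a page boundary (ceil(9/4)=3 vs ceil(8/4)=2),
+    # so it takes the next free page and out_indices[pid] = page * page_size;
+    # growing from 9 to 10 stays within the current page, so the kernel writes
+    # last_loc + 1 instead.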
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_alloc_extend_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_alloc_extend_kernel.py new file mode 100644 index 000000000..bd81850c0 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_alloc_extend_kernel.py @@ -0,0 +1,113 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +@triton.jit +def alloc_extend_kernel( + pre_lens_ptr, + seq_lens_ptr, + last_loc_ptr, + free_page_ptr, + out_indices, + ret_values, + bs_upper: tl.constexpr, + page_size: tl.constexpr, + max_num_extend_tokens: tl.constexpr, +): + pid = tl.program_id(0) + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid) + pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid) + extend_lens = seq_lens - pre_lens + + seq_len = tl.load(seq_lens_ptr + pid) + pre_len = tl.load(pre_lens_ptr + pid) + extend_len = seq_len - pre_len + + sum_extend_lens = tl.sum(extend_lens) + output_start_loc = sum_extend_lens - extend_len + + num_pages_after = (seq_lens + page_size - 1) // page_size + num_pages_before = (pre_lens + page_size - 1) // page_size + num_new_pages = num_pages_after - num_pages_before + + num_page_start_loc_self = (seq_len + page_size - 1) // page_size - ( + pre_len + page_size - 1 + ) // page_size + sum_num_new_pages = tl.sum(num_new_pages) + new_page_start_loc = sum_num_new_pages - num_page_start_loc_self + + # Return value + if pid == tl.num_programs(0) - 1: + merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to( + tl.int64 + ) + tl.store(ret_values, merged_value) + + # Part 1: fill the old partial page + last_loc = tl.load(last_loc_ptr + pid) + num_part1 = ( + min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len + ) + offset_one_page = tl.arange(0, page_size) + tl.store( + out_indices + output_start_loc + offset_one_page, + last_loc + 1 + offset_one_page, + mask=offset_one_page < num_part1, + ) + if pre_len + num_part1 == seq_len: + return + + # Part 2: fill the new full pages + num_part2 = ( + seq_len // page_size * page_size + - (pre_len + page_size - 1) // page_size * page_size + ) + + offset_many_page = tl.arange(0, max_num_extend_tokens) + page_start = tl.load( + free_page_ptr + new_page_start_loc + offset_many_page // page_size, + mask=offset_many_page < num_part2, + ) + tl.store( + out_indices + output_start_loc + num_part1 + offset_many_page, + page_start * page_size + offset_many_page % page_size, + mask=offset_many_page < num_part2, + ) + if pre_len + num_part1 + num_part2 == seq_len: + return + + # Part 3: fill the new partial page + num_part3 = seq_len - seq_len // page_size * page_size + start_loc = tl.load( + free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1 + ) + tl.store( + out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page, + start_loc * page_size + offset_one_page, + mask=offset_one_page < num_part3, + ) + + +def test_alloc_extend_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + 
alloc_extend_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_apply_token_bitmask_inplace_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_apply_token_bitmask_inplace_kernel.py new file mode 100644 index 000000000..38f28d731 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_apply_token_bitmask_inplace_kernel.py @@ -0,0 +1,104 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\constrained\triton_ops\bitmask_ops.py +@triton.jit +def apply_token_bitmask_inplace_kernel( + logits_ptr, + bitmask_ptr, + indices_ptr, + num_rows, + vocab_size, + logits_strides, + bitmask_strides, + NUM_SMS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor, + where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask, + the masked logits will be set to -inf. + + Parameters + ---------- + logits_ptr : tl.tensor + Pointer to the logits tensor to apply the bitmask to. + + bitmask_ptr : tl.tensor + Pointer to the bitmask tensor to apply. + + indices_ptr : Optional[tl.tensor] + Optional pointer to indices tensor specifying which rows to apply the mask to. + + num_rows : int + Number of rows to process. If indices_ptr is provided, this is the number of unique indices. + + vocab_size : int + Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the + same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary. + + logits_strides : int + Stride between rows in the logits tensor. + + bitmask_strides : int + Stride between rows in the bitmask tensor. + + NUM_SMS : int + Number of streaming multiprocessors to use. + + BLOCK_SIZE : int + Size of processing blocks. 
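+
+    Note
+    ----
+    With BLOCK_SIZE = 64, for example, each program expands two packed 32-bit
+    words per block into 64 boolean entries before writing -inf over the
+    masked logits.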
+ """ + + pid = tl.program_id(0) + num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE) + for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS): + row_id = work_id // num_blocks + block_offset = (work_id % num_blocks) * BLOCK_SIZE + batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id) + offsets = block_offset + tl.arange(0, BLOCK_SIZE) + bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32) + vocab_mask = offsets < vocab_size + packed_bitmask_mask = bitmask_offsets < bitmask_strides + packed_bitmask = tl.load( + bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, + packed_bitmask_mask, + ) + bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0 + bitmask = bitmask.reshape(BLOCK_SIZE) + + tl.store( + logits_ptr + batch_id * logits_strides + offsets, + -float("inf"), + vocab_mask & bitmask, + ) + + +def test_apply_token_bitmask_inplace_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + apply_token_bitmask_inplace_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_assign_draft_cache_locs.py b/third_party/ascend/test/sglang/v0.4.8/test_assign_draft_cache_locs.py new file mode 100644 index 000000000..6f377e3fa --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_assign_draft_cache_locs.py @@ -0,0 +1,110 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\speculative\eagle_utils.py + + +@triton.jit +def assign_draft_cache_locs( + req_pool_indices, + req_to_token, + seq_lens, + extend_lens, + num_new_pages_per_topk, + out_cache_loc, + pool_len: tl.constexpr, + topk: tl.constexpr, + speculative_num_steps: tl.constexpr, + page_size: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + pid = tl.program_id(axis=0) + + if page_size == 1 or topk == 1: + copy_len = topk * speculative_num_steps + out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps + else: + bs_offset = tl.arange(0, bs_upper) + copy_len = tl.load(extend_lens + pid) + cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid)) + out_cache_ptr = out_cache_loc + cum_copy_len + + # Part 1: Copy from out_cache_loc to req_to_token + kv_start = tl.load(seq_lens + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + num_loop = tl.cdiv(copy_len, BLOCK_SIZE) + for i in range(num_loop): + copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = copy_offset < copy_len + data = tl.load(out_cache_ptr + copy_offset, mask=mask) + tl.store(token_pool + kv_start + copy_offset, data, mask=mask) + + if page_size == 1 or topk == 1: + return + + # Part 2: Copy the indices for the last partial page + prefix_len = tl.load(seq_lens + pid) + last_page_len = prefix_len % page_size + offsets = tl.arange(0, page_size) + mask = 
offsets < last_page_len + num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid) + prefix_base = token_pool + prefix_len - last_page_len + + for topk_id in range(topk): + value = tl.load(prefix_base + offsets, mask=mask) + tl.store( + prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets, + value, + mask=mask, + ) + + # Part 3: Remove the padding in out_cache_loc + iter_offest = tl.arange(0, iter_upper) + for topk_id in range(topk): + indices = tl.load( + prefix_base + + topk_id * num_new_pages_per_topk_ * page_size + + last_page_len + + iter_offest, + mask=iter_offest < speculative_num_steps, + ) + tl.store( + out_cache_loc + + pid * topk * speculative_num_steps + + topk_id * speculative_num_steps + + iter_offest, + indices, + mask=iter_offest < speculative_num_steps, + ) + + +def test_assign_draft_cache_locs(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + assign_draft_cache_locs[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_assign_req_to_token_pool.py b/third_party/ascend/test/sglang/v0.4.8/test_assign_req_to_token_pool.py new file mode 100644 index 000000000..586bac253 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_assign_req_to_token_pool.py @@ -0,0 +1,61 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/speculative/eagle_utils.py +@triton.jit +def assign_req_to_token_pool( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0) + end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + save_offset = tl.arange(0, BLOCK_SIZE) + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + save_offset += BLOCK_SIZE + load_offset += BLOCK_SIZE + + +def test_assign_req_to_token_pool(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + assign_req_to_token_pool[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
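+    # out_offset in the kernel is the running sum of (end - start) over earlier
+    # requests, so each program copies its own slice of out_cache_loc into row
+    # req_pool_indices[pid] of req_to_token.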
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_compute_m_num_tiles_indptr.py b/third_party/ascend/test/sglang/v0.4.8/test_compute_m_num_tiles_indptr.py new file mode 100644 index 000000000..63d67a1ce --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_compute_m_num_tiles_indptr.py @@ -0,0 +1,37 @@ +import sys +import pytest +import torch + +import triton +import triton.language as tl + +sys.path.append("..") +import test_common + + +#source python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def compute_m_num_tiles_indptr( + m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr +): + for bs in range(batch_size): + m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) + cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) + pre_num_tiles = tl.load(m_num_tiles_indptr + bs) + tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + compute_m_num_tiles_indptr[data['grid']](**input_data) + + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_compute_masked_m_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_compute_masked_m_triton_kernel.py new file mode 100644 index 000000000..a9201f8d7 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_compute_masked_m_triton_kernel.py @@ -0,0 +1,34 @@ +import sys +import pytest +import torch + +import triton +import triton.language as tl + +sys.path.append("..") +import test_common + + +#source python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def compute_masked_m_triton_kernel(seg_indptr, masked_m): + expert_id = tl.program_id(0) + start = tl.load(seg_indptr + expert_id) + end = tl.load(seg_indptr + expert_id + 1) + tl.store(masked_m + expert_id, (end - start)) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + compute_masked_m_triton_kernel[data['grid']](**input_data) + + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_compute_position_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_compute_position_kernel.py new file mode 100644 index 000000000..3cbbbedfa --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_compute_position_kernel.py @@ -0,0 +1,64 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\model_executor\forward_batch_info.py + + +@triton.jit +def 
compute_position_kernel( + positions, + extend_start_loc, + extend_prefix_lens, + extend_seq_lens, + has_prefix: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(0).to(tl.int64) + + prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0 + seq_len = tl.load(extend_seq_lens + pid) + + # NOTE: This can be slow for large bs + cumsum_start = tl.cast(0, tl.int64) + for i in range(pid): + cumsum_start += tl.load(extend_seq_lens + i) + + num_loop = tl.cdiv(seq_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + tl.store( + positions + cumsum_start + offset, + prefix_len + offset, + mask=offset < seq_len, + ) + tl.store(extend_start_loc + pid, cumsum_start) + + +def test_compute_position_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + compute_position_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_compute_seg_indptr_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_compute_seg_indptr_triton_kernel.py new file mode 100644 index 000000000..c54a2ccc4 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_compute_seg_indptr_triton_kernel.py @@ -0,0 +1,43 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): + expert = tl.program_id(0) + low = 0 + high = num_toks - 1 + target_location = -1 + while low <= high: + mid = (low + high) // 2 + + if tl.load(reorder_topk_ids + mid) > expert: + high = mid - 1 + else: + low = mid + 1 + target_location = mid + tl.store(seg_indptr + expert + 1, target_location + 1) + + +def test_compute_seg_indptr_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + compute_seg_indptr_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py new file mode 100644 index 000000000..0d61387cf --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py @@ -0,0 +1,37 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def test_compute_src2dst_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + compute_src2dst_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") + \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_context_fwd_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_context_fwd_kernel.py new file mode 100644 index 000000000..37e68e649 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_context_fwd_kernel.py @@ -0,0 +1,168 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\triton_ops\prefill_attention.py +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + Lk: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] + + mask_d = offs_d < Lk + + q = tl.load( + Q + off_q, + mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], 
dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + end_n = ( + cur_batch_seq_len + if not IS_CAUSAL + else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + ) + for start_n in range(0, block_mask * end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if IS_CAUSAL: + qk += tl.where( + (start_n + offs_n[None, :] < cur_batch_seq_len) + & (offs_m[:, None] >= (start_n + offs_n[None, :])), + 0, + float("-inf"), + ) + else: + qk += tl.where( + (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf") + ) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] + ) + out_ptrs = Out + off_o + tl.store( + out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]) + ) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _fwd_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_copy_all_layer_kv_cache.py b/third_party/ascend/test/sglang/v0.4.8/test_copy_all_layer_kv_cache.py new file mode 100644 index 000000000..2d7103aff --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_copy_all_layer_kv_cache.py @@ -0,0 +1,105 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\mem_cache\memory_pool.py +@triton.jit +def copy_all_layer_kv_cache( + data_ptrs, + strides, + tgt_loc_ptr, + src_loc_ptr, + num_locs, + num_locs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + + bid = tl.program_id(0) + stride = tl.load(strides + bid) + + data_ptr = tl.load(data_ptrs + bid) + data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8)) + + num_locs_offset = tl.arange(0, num_locs_upper) + tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs) + src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs) + + # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks + # because this copy is an inplace operation. + + num_loop = tl.cdiv(stride, BLOCK_SIZE) + for i in range(num_loop): + copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :] + value = tl.load( + data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask + ) + tl.store( + data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :], + value, + mask=mask, + ) + + +def test_copy_all_layer_kv_cache(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (2,) + + input_data = data['input_data'] + device_type = 'npu' + + num_layers = input_data['num_layers'] + max_seq_len = input_data['max_seq_len'] + hidden_size = input_data['hidden_size'] + num_locs = input_data['num_locs'] + num_locs_upper = input_data['num_locs_upper'] + + src_loc = input_data['src_loc_ptr'].to(device_type) + tgt_loc = input_data['tgt_loc_ptr'].to(device_type) + + strides = input_data['strides'].to(device_type) + + kv_caches = [] + data_ptrs = [] + + for i in range(num_layers): + cache = torch.zeros((max_seq_len, hidden_size), dtype=torch.float16).to(device_type) + for pos in range(max_seq_len): + cache[pos, :] = (i * 100 + pos) * torch.ones(hidden_size, dtype=torch.float16).to(device_type) + kv_caches.append(cache) + data_ptrs.append(cache.data_ptr()) + + data_ptrs_tensor = torch.tensor(data_ptrs, dtype=torch.uint64).to(device_type) + + copy_all_layer_kv_cache[data['grid']]( + data_ptrs=data_ptrs_tensor, + strides=strides, + tgt_loc_ptr=tgt_loc, + src_loc_ptr=src_loc, + num_locs=num_locs, + num_locs_upper=num_locs_upper, + ) + + # compare the results of GPU and NPU. 
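+    # Both sides are flattened before comparison: the recorded GPU per-layer caches
+    # and the NPU per-layer caches built above are concatenated along dim 0 and fed
+    # to the shared precision check.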
+    gpu_out = {'kv_caches': torch.cat(data['gpu_output']['kv_caches'], dim=0)}
+    npu_out = {'kv_caches': torch.cat(kv_caches, dim=0)}
+    try:
+        test_common.compare_data_precision(gpu_out, npu_out, device_type='cpu')
+    except ValueError as e:
+        pytest.fail(f"The testcase failed")
\ No newline at end of file
diff --git a/third_party/ascend/test/sglang/v0.4.8/test_create_chunked_prefix_cache_kv_indices.py b/third_party/ascend/test/sglang/v0.4.8/test_create_chunked_prefix_cache_kv_indices.py
new file mode 100644
index 000000000..cba1016df
--- /dev/null
+++ b/third_party/ascend/test/sglang/v0.4.8/test_create_chunked_prefix_cache_kv_indices.py
@@ -0,0 +1,71 @@
+import sys
+import pytest
+import triton
+import torch
+import triton.language as tl
+
+sys.path.append("..")
+import test_common
+
+
+# source: python/sglang/srt/model_executor/forward_batch_info.py
+@triton.jit
+def create_chunked_prefix_cache_kv_indices(
+    req_to_token_ptr,  # (max_batch, max_context_len,)
+    req_pool_indices_ptr,  # (batch_size,)
+    chunk_start_idx_ptr,  # (batch_size,)
+    chunk_seq_lens_ptr,  # (batch_size,)
+    chunk_cu_seq_lens_ptr,  # (batch_size + 1,)
+    chunk_kv_indices_ptr,  # (num_chunk_tokens,)
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
+
+    # get the token position of current chunk
+    chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
+    chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
+
+    num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < chunk_seq_len
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + chunk_start_pos
+            + offset,
+            mask=mask,
+        )
+        tl.store(
+            chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
+        )
+
+
+def test_create_chunked_prefix_cache_kv_indices(ptfile_path):
+    try:
+        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
+    except Exception as e:
+        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")
+
+    # ptfile format:
+    # [input_data] (dict):
+    #     key : value
+    # [gpu_output] (dict):
+    #     key : value
+    # [grid] :
+    #     (3,)
+    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')
+
+    create_chunked_prefix_cache_kv_indices[data["grid"]](**input_data)
+
+    # compare the results of GPU and NPU.
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_create_extend_after_decode_spec_info.py b/third_party/ascend/test/sglang/v0.4.8/test_create_extend_after_decode_spec_info.py new file mode 100644 index 000000000..73e4e2834 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_create_extend_after_decode_spec_info.py @@ -0,0 +1,58 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +# source : python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def create_extend_after_decode_spec_info( + verified_id_ptr, + seq_lens_ptr, + accept_lens_ptr, + positions_ptr, + new_verified_id_ptr, + bs_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, bs_upper) + seq_length = tl.load(seq_lens_ptr + pid) + accept_length = tl.load(accept_lens_ptr + pid) + + accept_len_cumsum = tl.sum( + tl.load(accept_lens_ptr + offsets, mask=offsets < pid, other=0) + ) + positions_ptr = positions_ptr + accept_len_cumsum + mask = offsets < accept_length + tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask) + + accept_len_cumsum += accept_length - 1 + verified_id_data = tl.load(verified_id_ptr + accept_len_cumsum) + tl.store(new_verified_id_ptr + pid, verified_id_data) + + +def test_create_extend_after_decode_spec_info(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + create_extend_after_decode_spec_info[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_create_flashinfer_kv_indices_triton.py b/third_party/ascend/test/sglang/v0.4.8/test_create_flashinfer_kv_indices_triton.py new file mode 100644 index 000000000..829aeffe3 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_create_flashinfer_kv_indices_triton.py @@ -0,0 +1,72 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\utils.py +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + kv_indices_ptr, + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + + # find the req pool idx, this is for batch to token + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for i in range(num_loop): + # index into req_to_token_ptr needs to be int64 + offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE + mask = offset < kv_end - kv_start + data = tl.load( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + kv_start + + offset, + mask=mask, + ) + tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask) + + +def test_create_flashinfer_kv_indices_triton(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + create_flashinfer_kv_indices_triton[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_create_flashmla_kv_indices_triton.py b/third_party/ascend/test/sglang/v0.4.8/test_create_flashmla_kv_indices_triton.py new file mode 100644 index 000000000..63dc1d4c7 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_create_flashmla_kv_indices_triton.py @@ -0,0 +1,86 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\utils.py +@triton.jit +def create_flashmla_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_start_idx, + kv_indices_ptr, + req_to_token_ptr_stride: tl.constexpr, + kv_indices_ptr_stride: tl.constexpr, + PAGED_SIZE: tl.constexpr = 64, +): + BLOCK_SIZE: tl.constexpr = 4096 + NUM_PAGE_PER_BLOCK: tl.constexpr = 64 + pid = tl.program_id(axis=0) + + # find the req pool idx, this is for batch to token + req_pool_index = tl.load(req_pool_indices_ptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE) + num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + + for i in range(num_pages_loop): + # index into req_to_token_ptr needs to be int64 + paged_offset = ( + tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK + ) * PAGED_SIZE + paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK + + mask = paged_offset < num_paged * PAGED_SIZE + mask_out = paged_offset_out < num_paged + + data = tl.load( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + kv_start + + paged_offset, + mask=mask, + ) + tl.store( + kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out, + data // PAGED_SIZE, + mask=mask_out, + ) + + +def test_create_flashmla_kv_indices_triton(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + create_flashmla_kv_indices_triton[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_decode_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_decode_kernel.py new file mode 100644 index 000000000..cfa7b1c0a --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_decode_kernel.py @@ -0,0 +1,94 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: sgl-kernel/benchmark/bench_lightning_attention_decode.py +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _decode_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_deepep_compute_src2dst_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_deepep_compute_src2dst_triton_kernel.py new file mode 100644 index 000000000..1689b404c --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_deepep_compute_src2dst_triton_kernel.py @@ -0,0 +1,38 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def deepep_compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + num_invalid = tl.load(num_minus_one) + tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask) + + +def test_deepep_compute_src2dst_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + deepep_compute_src2dst_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_deepep_permute_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_deepep_permute_triton_kernel.py new file mode 100644 index 000000000..be1faeeae --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_deepep_permute_triton_kernel.py @@ -0,0 +1,62 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +# source : python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def deepep_permute_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype) + + for idx in range(topk): + dst_idx = tl.load(src2dst_ptr + idx) + if dst_idx >= 0: + dst_ptr = gateup_input_ptr + dst_idx * hidden_size + tl.store(dst_ptr + offset, in_data, mask=mask) + +def test_deepep_permute_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + 
deepep_permute_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_deepep_post_reorder_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_deepep_post_reorder_triton_kernel.py new file mode 100644 index 000000000..b65566e7d --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_deepep_post_reorder_triton_kernel.py @@ -0,0 +1,59 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def deepep_post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + store_ptr = output_ptr + src_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + dst_idx = tl.load(src2dst_ptr + idx) + if dst_idx >= 0: + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + deepep_post_reorder_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_deepgemm_compute_src2dst_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_deepgemm_compute_src2dst_triton_kernel.py new file mode 100644 index 000000000..df66a2d45 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_deepgemm_compute_src2dst_triton_kernel.py @@ -0,0 +1,48 @@ +import sys +import pytest +import torch + +import triton +import triton.language as tl + + +sys.path.append("..") +import test_common + + +#source python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def deepgemm_compute_src2dst_triton_kernel( + topk_ids, + reorder_ids, + seg_indptr, + src2dst, + m_max, + num_toks, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks)) + expert_dst_start = tl.load(seg_indptr + expert_id) + expert_dst_offset = dst_id - expert_dst_start + dst_id = expert_id * m_max + expert_dst_offset + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + deepgemm_compute_src2dst_triton_kernel[data['grid']](**input_data) + + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py new file mode 100644 index 000000000..e3018014a --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py @@ -0,0 +1,298 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +import test_common +sys.path.append("..") + + +# source: python/sglang/srt/layers/attention/triton_ops/extend_attention.py +@triton.jit +def _fwd_kernel( + Q_Extend, + K_Extend, + V_Extend, + O_Extend, + K_Buffer, + V_Buffer, + qo_indptr, + kv_indptr, + kv_indices, + mask_ptr, + mask_indptr, + sm_scale, + kv_group_num, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + SLIDING_WINDOW_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_CUSTOM_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + SKIP_PREFIX_CUSTOM_MASK: tl.constexpr, + STORE_TRANSPOSE: tl.constexpr, +): + cur_seq = tl.program_id(0) + cur_head = tl.program_id(1) + cur_block_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq) + cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx + cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq) + cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 
1) - cur_seq_kv_start_idx + cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend + + if USE_CUSTOM_MASK: + cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + offs_m = tl.arange(0, BLOCK_M) + mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + + offs_q = ( + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + q = tl.load( + Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 + ) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + offs_qpe = ( + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_dpe[None, :] + ) + qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) + + # stage 1: compute scores with prefix + offs_n = tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) + deno = tl.zeros([BLOCK_M], dtype=tl.float32) + e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + for start_n in range(0, cur_seq_len_prefix, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_seq_len_prefix + + offs_kv_loc = tl.load( + kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0 + ) + + # load k in transposed way + offs_buf_k = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) + + qk = tl.dot(q.to(k.dtype), k) + if BLOCK_DPE > 0: + offs_kpe = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe.to(kpe.dtype), kpe) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + final_mask = mask_m[:, None] & mask_n[None, :] + if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK: + custom_mask = tl.load( + mask_ptr + + cur_seq_mask_start_idx + + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len + + start_n + + offs_n[None, :], + mask=(mask_m[:, None] & mask_n[None, :]), + other=0, + ) + final_mask &= custom_mask + if SLIDING_WINDOW_SIZE > 0: + # Add mask where q_id <= kv_id + sliding_window_size + window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= ( + start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE + ) + final_mask &= window_mask + qk = tl.where(final_mask, qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_buf_v = ( + offs_kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + # stage 2: compute the triangle part + + cur_block_m_end = ( + cur_seq_len_extend + if not IS_CAUSAL + else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) + ) + for start_n in range(0, cur_block_m_end, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_block_m_end + + # load k in transposed way + offs_k = ( + (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * 
stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] + ) + k = tl.load( + K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) + + qk = tl.dot(q, k, out_dtype=tl.float32) + if BLOCK_DPE > 0: + offs_kpe = ( + (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs + + cur_kv_head * stride_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Extend + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe, kpe) + + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + if USE_CUSTOM_MASK: + custom_mask = tl.load( + mask_ptr + + cur_seq_mask_start_idx + + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len + + cur_seq_len_prefix + + start_n + + offs_n[None, :], + mask=(mask_m[:, None] & mask_n[None, :]), + other=0, + ) + custom_mask &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(custom_mask, qk, float("-inf")) + elif IS_CAUSAL: + mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( + start_n + offs_n[None, :] + ) + mask_causual &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_causual, qk, float("-inf")) + else: + mask_non_causal = mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_non_causal, qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_v = ( + (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs + + cur_kv_head * stride_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + offs_o = ( + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_obs + + cur_head * stride_oh + + offs_dv[None, :] + ) + if STORE_TRANSPOSE: + tl.store( + O_Extend + offs_o.T, + (acc / deno[:, None]).T, + mask=(mask_m[:, None] & mask_dv[None, :]).T, + ) + else: + tl.store( + O_Extend + offs_o, + acc / deno[:, None], + mask=mask_m[:, None] & mask_dv[None, :], + ) + + +def test_extend_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _fwd_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fill_gateup_input_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_fill_gateup_input_triton_kernel.py new file mode 100644 index 000000000..24af61ca3 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fill_gateup_input_triton_kernel.py @@ -0,0 +1,71 @@ +import sys +import pytest +import torch + +import triton +import triton.language as tl + +sys.path.append("..") +import test_common + +#source python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def fill_gateup_input_triton_kernel( + input_ptr, + scale_ptr, + gateup_input_ptr, + gateup_input_scale_ptr, + src2dst_ptr, + topk_ids_ptr, + start_expert_id, + end_expert_id, + topk, + m_max, + hidden_size, + scale_size, + BLOCK_SIZE: tl.constexpr, +): + + src_idx_int32 = tl.program_id(0) + src_idx = src_idx_int32.to(tl.int64) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + src_ptr = input_ptr + src_idx * hidden_size + scale_src_ptr = scale_ptr + src_idx * scale_size + + vec = tl.arange(0, BLOCK_SIZE) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + dst_idx_int32 = tl.load(src2dst_ptr + idx) + dst_idx = dst_idx_int32.to(tl.int64) + dst_idx = dst_idx - start_expert_id * m_max + dst_ptr = gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask) + tl.store(dst_ptr + offset, in_data, mask=mask) + scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size + for start_offset in tl.range(0, scale_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < scale_size + in_scale = tl.load(scale_src_ptr + offset, mask=mask) + tl.store(scale_dst_ptr + offset, in_scale, mask=mask) + + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + fill_gateup_input_triton_kernel[data['grid']](**input_data) + + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_filter_finished_cache_loc_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_filter_finished_cache_loc_kernel.py new file mode 100644 index 000000000..65511a5db --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_filter_finished_cache_loc_kernel.py @@ -0,0 +1,60 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/speculative/eagle_utils.py +@triton.jit +def filter_finished_cache_loc_kernel( + out_cache_loc, + tgt_cache_loc, + accept_length, + accept_length_filter, + bs_upper: tl.constexpr, + num_verify_tokens_upper: tl.constexpr, +): + bid = tl.program_id(0) + bs_offset = tl.arange(0, bs_upper) + + accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < 
bid) + old_start = tl.sum(accept_length_all) + bid + + accept_length_filter_all = tl.load( + accept_length_filter + bs_offset, mask=bs_offset < bid + ) + new_start = tl.sum(accept_length_filter_all) + + copy_len = tl.load(accept_length_filter + bid) + copy_offset = tl.arange(0, num_verify_tokens_upper) + value = tl.load( + tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len + ) + tl.store( + out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len + ) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + filter_finished_cache_loc_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fused_dual_residual_rmsnorm_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_fused_dual_residual_rmsnorm_kernel.py new file mode 100644 index 000000000..18e2ce539 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fused_dual_residual_rmsnorm_kernel.py @@ -0,0 +1,79 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\elementwise.py +@triton.jit +def fused_dual_residual_rmsnorm_kernel( + output_ptr, + mid_ptr, + activ_ptr, + residual_ptr, + weight1_ptr, + weight2_ptr, + eps: tl.constexpr, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + input_start = pid * hidden_dim + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + + a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0) + a = a_.to(tl.float32) + rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps) + + r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0) + w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0) + w1 = w1_.to(tl.float32) + + a2r = r + (a / rms * w1).to(r.dtype) + tl.store( + mid_ptr + input_start + offsets, + a2r, + mask=mask, + ) + + a2r = a2r.to(tl.float32) + rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps) + + w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0) + w2 = w2_.to(tl.float32) + + tl.store( + output_ptr + input_start + offsets, + a2r / rms2 * w2, # implicitly casts to output dtype here + mask=mask, + ) + + +def test_fused_dual_residual_rmsnorm_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + fused_dual_residual_rmsnorm_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fused_moe_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_fused_moe_kernel.py new file mode 100644 index 000000000..62182e091 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fused_moe_kernel.py @@ -0,0 +1,246 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, + even_Ks: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. 
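+    # Illustrative example of the grouped ordering computed below: with
+    # GROUP_SIZE_M = 2 and num_pid_n = 4, pids 0..7 map to C blocks
+    # (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3), so the two
+    # programs in each column share the same B tile while it is hot in L2.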
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + if use_int8_w8a16: + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8 or use_int8_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + # channel-wise + elif per_channel_quant: + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + if even_Ks: + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + # We accumulate along the K dimension. 
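+        # Three accumulation paths follow: int8_w8a16 casts b to compute_type inside
+        # the dot; fp8/int8 w8a8 applies per-K-group scales here (block-wise) or
+        # defers tensor/channel-wise scaling until after the loop; otherwise it is a
+        # plain fp32 accumulate.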
+ if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + # fix out of shared memory issue + if use_fp8_w8a8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +def test_fused_moe_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + fused_moe_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fused_moe_router_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_fused_moe_router_kernel.py new file mode 100644 index 000000000..baa1050d0 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fused_moe_router_kernel.py @@ -0,0 +1,123 @@ +import triton +import triton.language as tl +import torch +import pytest +import os +import test_common + +@triton.jit +def fused_moe_router_kernel( + input_ptr, # input (bs, hidden_dim) + moe_router_weight_ptr, # input (num_experts, hidden_dim) + topk_weights_ptr, # output (bs, topk) + topk_ids_ptr, # output (bs, topk) + correction_bias_ptr, + is_correction_bias: tl.constexpr, + num_experts: tl.constexpr, + topk: tl.constexpr, + moe_softcapping: tl.constexpr, + moe_renormalize: tl.constexpr, # not supported + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + + # moe_router_weight is k major + expert_offsets = tl.arange(0, num_experts)[:, None] + router_mask = mask[None, :] + w_router = tl.load( + moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :], + mask=router_mask, + other=0.0, + ) + + x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0) + + # todo: tl.dot? + logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1) + + # logit softcap + logits_scaled = logits / moe_softcapping + exped = tl.exp(2 * logits_scaled) + top = exped - 1 + bottom = exped + 1 + logits_softcapped = top / bottom * moe_softcapping + + # Add bias after softcapping + if is_correction_bias: + bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts)) + logits_softcapped = logits_softcapped + bias + + # topk + top1 = tl.argmax(logits_softcapped, axis=0) + tl.store(topk_ids_ptr + pid * topk + 0, top1) + + top1_v = tl.max(logits_softcapped, axis=0) + invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0) + + tl.store( + topk_weights_ptr + pid * topk + 0, + invsumexp, + ) + + if topk >= 2: + top2 = tl.argmax( + tl.where( + tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf") + ), + axis=0, + ) + tl.store(topk_ids_ptr + pid * topk + 1, top2) + top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0) + tl.store( + topk_weights_ptr + pid * topk + 1, + tl.exp(top2_v - top1_v) * invsumexp, + ) + + if topk > 2: + topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype) + topk_mask = tl.where( + tl.arange(0, num_experts) != top1, topk_mask, float("-inf") + ) + topk_mask = tl.where( + tl.arange(0, num_experts) != top2, topk_mask, float("-inf") + ) + for i in range(2, topk): + topi = tl.argmax(logits_softcapped + topk_mask, axis=0) + topk_mask = tl.where( + tl.arange(0, num_experts) != topi, topk_mask, float("-inf") + ) + tl.store(topk_ids_ptr + pid * topk + i, topi) + topi_v = tl.sum( + logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0 + ) + tl.store( + topk_weights_ptr + pid * topk + i, + tl.exp(topi_v - top1_v) * invsumexp, + ) + +def test_fused_moe_router_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # 
key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + fused_moe_router_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fused_rmsnorm_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_fused_rmsnorm_kernel.py new file mode 100644 index 000000000..db6e17de0 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fused_rmsnorm_kernel.py @@ -0,0 +1,64 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\elementwise.py +@triton.jit +def fused_rmsnorm_kernel( + output_ptr, + activ_ptr, + weight_ptr, + eps: tl.constexpr, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + input_start = pid * hidden_dim + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + + a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0) + a = a_.to(tl.float32) + rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps) + + w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0) + w1 = w1_.to(tl.float32) + + a_rms = a / rms * w1 + + tl.store( + output_ptr + input_start + offsets, + a_rms, # implicitly casts to output dtype here + mask=mask, + ) + + +def test_fused_rmsnorm_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + fused_rmsnorm_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
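+    # The kernel normalizes each row in float32: y = x / sqrt(mean(x * x) + eps) * w,
+    # and tl.store casts the result back to the output dtype. An illustrative torch
+    # reference (assuming a 2-D activation tensor x and weight w) would be:
+    #     ref = x.float() / torch.sqrt((x.float() ** 2).mean(-1, keepdim=True) + eps) * w.float()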
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fused_softcap_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_fused_softcap_kernel.py new file mode 100644 index 000000000..4e6b87baf --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fused_softcap_kernel.py @@ -0,0 +1,53 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\layers\elementwise.py +@triton.jit +def fused_softcap_kernel( + output_ptr, + input_ptr, + n_ele, + softcap_const: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_ele + x = tl.load(input_ptr + offsets, mask=mask) + fx = x.to(tl.float32) + fxs = fx / softcap_const + exped = tl.exp(2 * fxs) + top = exped - 1 + bottom = exped + 1 + output = top / bottom * softcap_const + tl.store(output_ptr + offsets, output, mask=mask) + +def test_fused_softcap_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + fused_softcap_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fwd_grouped_kernel_stage1.py b/third_party/ascend/test/sglang/v0.4.8/test_fwd_grouped_kernel_stage1.py new file mode 100644 index 000000000..be1a468c5 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fwd_grouped_kernel_stage1.py @@ -0,0 +1,194 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +# source: python\sglang\srt\layers\attention\triton_ops\decode_attention.py +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = 
tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + if split_kv_end > split_kv_start: + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + if BLOCK_DPE > 0: + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def test_fwd_grouped_kernel_stage1(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + 
_fwd_grouped_kernel_stage1[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_ep_gather.py b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_ep_gather.py new file mode 100644 index 000000000..3d65e346d --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_ep_gather.py @@ -0,0 +1,100 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\layers\moe\ep_moe\kernels.py + + +@triton.jit +def _fwd_kernel_ep_gather( + total_token_num, + input_tensor, + input_tensor_stride0, + input_tensor_stride1, + recv_topk_ids, + recv_topk_ids_stride0, + recv_topk_ids_stride1, + recv_topk_weight, + recv_topk_weight_stride0, + recv_topk_weight_stride1, + input_index, + input_index_stride0, + input_index_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + topk_num: tl.constexpr, + BLOCK_D: tl.constexpr, +): + cur_block_int32 = tl.program_id(0) + cur_block = cur_block_int32.to(tl.int64) + + start_cur_token_int32 = tl.program_id(1) + + grid_num = tl.num_programs(1) + + for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num): + cur_token = cur_token_int32.to(tl.int64) + + off_d = tl.arange(0, BLOCK_D) + accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) + + for topk_index_int32 in range(0, topk_num): + topk_index = topk_index_int32.to(tl.int64) + + expert_id = tl.load( + recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index + ) + if expert_id >= 0: + source_token_index_int32 = tl.load( + input_index + cur_token * input_index_stride0 + topk_index + ) + source_token_index = source_token_index_int32.to(tl.int64) + + acc_weight = tl.load( + recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index + ) + tmp = tl.load( + input_tensor + + source_token_index * input_tensor_stride0 + + cur_block * BLOCK_D + + off_d + ) + accumulator += tmp.to(tl.float32) * acc_weight + + tl.store( + output_tensor + + cur_token * output_tensor_stride0 + + cur_block * BLOCK_D + + off_d, + accumulator.to(output_tensor.dtype.element_ty), + ) + + +def test_fwd_kernel_ep_gather(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _fwd_kernel_ep_gather[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
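+    # _fwd_kernel_ep_gather accumulates, per token, the topk-weighted sum of the
+    # gathered expert rows in float32 (entries with expert_id < 0 are skipped) and
+    # casts to output_tensor's element type on store; output_tensor is what is compared.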
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_flash_decode_stage1.py b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_flash_decode_stage1.py new file mode 100644 index 000000000..fbd93eccc --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_flash_decode_stage1.py @@ -0,0 +1,144 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\triton_ops\double_sparsity_attention.py +@triton.jit +def _fwd_kernel_flash_decode_stage1( + Q, + K, + V, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum( + cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ + ) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + k = tl.load( + K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0 + ) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) + v = tl.load( + V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0 + ) + + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + seq_start_block * stride_mid_os + + offs_d + ) + off_mid_o_logexpsum = ( + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + ) + 
tl.store(Mid_O + off_mid_o, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + return + + +def test_fwd_kernel_flash_decode_stage1(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _fwd_kernel_flash_decode_stage1[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_stage1.py b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_stage1.py new file mode 100644 index 000000000..e9c010a05 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_stage1.py @@ -0,0 +1,175 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +import test_common +sys.path.append("..") + + +# source: python/sglang/srt/layers/attention/triton_ops/decode_attention.py +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v 
= tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def test_fwd_kernel_stage1(ptfile_path="/home/z00946769/triton-sglang/full_data/test_data_fwd_kernel_stage1_full.pt"): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _fwd_kernel_stage1[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + dict_ref, dict_cal = data["gpu_output"], input_data + keys_ref, keys_cal = set(dict_ref.keys()), set(dict_cal.keys()) + if not keys_ref.issubset(keys_cal): + raise ValueError("The keys of dict_ref is not subset of dict_cal") + + for key in dict_ref.keys(): + val_a, val_b = dict_ref[key], dict_cal[key] + if type(val_a) != type(val_b): + raise ValueError("The data type of two dicts are different") + + if isinstance(val_a, torch.Tensor): + # In fwd_kernel_stage1, the actual accuracy of NPU is higher than that of GPU, + # so the restrictions are appropriately loosened when comparing with GPU. 
+ torch.testing.assert_close(val_a.cpu(), val_b.cpu(), rtol=5e-03, atol=5e-03, equal_nan=True) + else: + raise ValueError("Non-tensor type is not currently supported") + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_stage2.py b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_stage2.py new file mode 100644 index 000000000..eb18cd7f8 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_fwd_kernel_stage2.py @@ -0,0 +1,96 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +import test_common +sys.path.append("..") + + +# source: python/sglang/srt/layers/attention/triton_ops/decode_attention.py +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load( + kv_indptr + cur_batch + ) + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def test_fwd_kernel_stage2(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _fwd_kernel_stage2[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
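+    # _fwd_kernel_stage2 merges the per-split partial outputs (Mid_O) using their
+    # log-sum-exp values (Mid_O_1) with online-softmax rescaling and stores the
+    # normalized result acc / e_sum into O, masked to the first Lv lanes.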
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_gelu_and_mul_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_gelu_and_mul_kernel.py new file mode 100644 index 000000000..e7a5901a4 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_gelu_and_mul_kernel.py @@ -0,0 +1,71 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\elementwise.py +@triton.jit +def gelu_and_mul_kernel( + out_hidden_states_ptr, # (bs, hidden_dim) + out_scales_ptr, # (bs,) + hidden_states_ptr, # (bs, hidden_dim * 2) + quant_max: tl.constexpr, + static_scale: tl.constexpr, + hidden_dim: tl.constexpr, # the output hidden_dim + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + input_start = pid * hidden_dim * 2 + output_start = pid * hidden_dim + + input1_offs = tl.arange(0, BLOCK_SIZE) + mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output + input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE) + output_offs = tl.arange(0, BLOCK_SIZE) + + x1 = tl.load( + hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0 + ).to(tl.float32) + x3 = tl.load( + hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0 + ).to(tl.float32) + + # gelu + # cast down before mul to better match training? + gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1 + out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty) + + if quant_max is not None: + raise NotImplementedError() + + tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask) + + +def test_gelu_and_mul_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + gelu_and_mul_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
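+    # gelu_and_mul_kernel computes out = x3 * gelu(x1) with the exact erf-based GELU,
+    # 0.5 * (1 + erf(x / sqrt(2))) * x, casting gelu(x1) back to the hidden-state dtype
+    # before the multiply; the quant path raises NotImplementedError, so the recorded
+    # inputs are expected to carry quant_max=None.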
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_gelu_and_mul_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_gelu_and_mul_triton_kernel.py new file mode 100644 index 000000000..5202535c0 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_gelu_and_mul_triton_kernel.py @@ -0,0 +1,96 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def gelu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, + BLOCK_SIZE: tl.constexpr, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # gelu & mul & quantize + kAlpha = 0.7978845608028654 + gate_output = ( + 0.5 + * gate_output + * ( + 1 + + tanh( + kAlpha + * ( + gate_output + + 0.044715 * gate_output * gate_output * gate_output + ) + ) + ) + ) + gate_output = gate_output.to(InDtype) + + gelu_mul_output = gate_output * up_output * scale + gelu_mul_output = gelu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask) + + +def test_gelu_and_mul_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + gelu_and_mul_triton_kernel[data["grid"]](**input_data) + + result_npu = input_data["down_input"].cpu() + result_npu_32 = result_npu.to(torch.float32) + + gpu_output = data['gpu_output'] + gpu_output = gpu_output["down_input"] + # compare the results of GPU and NPU. 
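+    # This variant uses the tanh approximation of GELU (kAlpha = sqrt(2 / pi)), fused
+    # with the up-projection multiply and the optional per-expert 1 / scale, and writes
+    # into down_input, which is why down_input is extracted for the comparison above.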
+ try: + test_common.verify_precision_by_gold_standard(result_npu_32, result_npu, gpu_output) + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_get_last_loc_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_get_last_loc_kernel.py new file mode 100644 index 000000000..cb3156ac1 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_get_last_loc_kernel.py @@ -0,0 +1,50 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\managers\schedule_batch.py +@triton.jit +def get_last_loc_kernel( + req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + result, + num_tokens, + req_to_token_stride, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + mask = offset < num_tokens + + prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0) + req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0) + + token_mask = prefix_lens > 0 + token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1) + tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1) + + tl.store(result + offset, tokens, mask=mask) + + +def test_get_last_loc_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + get_last_loc_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_get_num_kv_splits_triton.py b/third_party/ascend/test/sglang/v0.4.8/test_get_num_kv_splits_triton.py new file mode 100644 index 000000000..7c29a0a08 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_get_num_kv_splits_triton.py @@ -0,0 +1,84 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\attention\triton_backend.py +@triton.jit +def get_num_kv_splits_triton( + num_kv_splits_ptr, + seq_lens_ptr, + num_seq, + num_group, + num_head, + num_kv_head, + max_kv_splits, + device_core_count, + MAX_NUM_SEQ: tl.constexpr, +): + offs_seq = tl.arange(0, MAX_NUM_SEQ) + mask_seq = offs_seq < num_seq + + seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0) + max_seq_len = tl.max(seq_lens) + seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len) + min_seq_len = tl.min(seq_lens) + if max_seq_len * 8 < min_seq_len * 10: + min_seq_len = max_seq_len + max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits) + kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1) + + # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually + ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0 + ext_device_core_count = tl.cast( + device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32 + ) + block_h, num_kv_group = 16, num_head // num_kv_head + if num_kv_group == 1: + token_grid = num_seq * num_group * num_head + else: + # 
from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd + block_h = tl.minimum(block_h, num_kv_group) + token_grid = num_seq * num_group * tl.cdiv(num_head, block_h) + max_kv_splits_2 = tl.minimum( + tl.cdiv(ext_device_core_count, token_grid), max_kv_splits + ) + kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2) + + num_kv_splits = tl.maximum( + tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2) + ) + + offs_token = offs_seq * num_group + mask_token = offs_token < num_seq * num_group + for i in range(0, num_group): + tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token) + + +def test_get_num_kv_splits_triton(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + get_num_kv_splits_triton[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_get_target_cache_loc.py b/third_party/ascend/test/sglang/v0.4.8/test_get_target_cache_loc.py new file mode 100644 index 000000000..e87984830 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_get_target_cache_loc.py @@ -0,0 +1,76 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/speculative/eagle_utils.py +@triton.jit +def get_target_cache_loc( + tgt_cache_loc, + to_free_slots, + accept_length, + to_free_num_slots, + out_cache_loc, + num_verify_tokens: tl.constexpr, + num_verify_tokens_upper: tl.constexpr, + bs_upper: tl.constexpr, +): + bid = tl.program_id(axis=0) + offset = tl.arange(0, num_verify_tokens_upper) + bs_offset = tl.arange(0, bs_upper) + + # write the first part to tgt_cache_loc + accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) + tgt_cache_loc_start = tl.sum(accept_len_all) + bid + copy_len = tl.load(accept_length + bid) + 1 + out_cache_loc_row = tl.load( + out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len + ) + tl.store( + tgt_cache_loc + tgt_cache_loc_start + offset, + out_cache_loc_row, + mask=offset < copy_len, + ) + + # write the second part to to_free_num_pages + to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid) + to_free_num_slots_cur = tl.load(to_free_num_slots + bid) + out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur + to_free_slots_start = tl.sum(to_free_num_slots_all) + + copy_len = to_free_num_slots_cur + out_cache_loc_row = tl.load( + out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset, + mask=offset < copy_len, + ) + tl.store( + to_free_slots + to_free_slots_start + offset, + out_cache_loc_row, + mask=offset < copy_len, + ) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # 
[grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + get_target_cache_loc[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_grouped_gemm_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_grouped_gemm_triton_kernel.py new file mode 100644 index 000000000..6e1ce24ef --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_grouped_gemm_triton_kernel.py @@ -0,0 +1,159 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +import test_common +sys.path.append("..") + + +# source: python/sglang/srt/layers/attention/triton_ops/decode_attention.py +@triton.jit +def compute_m_range( + pid, + batch_size, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + BLOCK_SIZE_M: tl.constexpr, +): + idx = 0 + for bs in range(batch_size): + tiles = tl.load(m_num_tiles_indptr + bs) + if pid >= tiles: + idx = bs + + idx_start = tl.load(m_num_tiles_indptr + idx) + + m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M + m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M) + expert_id = tl.load(weight_indices + idx) + return m_range_start, m_range_end, expert_id + + +@triton.jit +def grouped_gemm_triton_kernel( + a, + b, + c, + batch_size, + N, + K, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + scale_a, + scale_b, + use_fp8_w8a8: tl.constexpr, + group_n: tl.constexpr, + group_k: tl.constexpr, + a_stride_0: tl.constexpr, + b_stride_0: tl.constexpr, + b_stride_1: tl.constexpr, + as_stride_0: tl.constexpr, + as_stride_1: tl.constexpr, + bs_stride_0: tl.constexpr, + bs_stride_2: tl.constexpr, + bs_stride_1: tl.constexpr, + use_per_token_if_dynamic: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + c_dtype = c.dtype.element_ty + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + total_m_block = tl.load(m_num_tiles_indptr + batch_size) + if pid_m >= total_m_block: + return + + m_range_start, m_range_end, expert_id = compute_m_range( + pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M + ) + if m_range_end - m_range_start == 0: + return + + n_range_start = pid_n * BLOCK_SIZE_N + n_range_end = min(n_range_start + BLOCK_SIZE_N, N) + + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) + offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] + b_ptr = b + ( + (expert_id * b_stride_0) + + (n_range_start + offs_bn[:, None]) * b_stride_1 + + offs_k[None, :] + ) + + if group_k > 0 and group_n > 0: + a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0 + offs_bsn = (n_range_start + offs_bn) // group_n + b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_tile = tl.load( + 
a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + b_tile = tl.load( + b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1) + b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2) + accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :] + else: + accumulator = tl.dot(a_tile, b_tile.T, accumulator) + a_ptr += BLOCK_SIZE_K + b_ptr += BLOCK_SIZE_K + + if use_fp8_w8a8 and not (group_k > 0 and group_n > 0): + if use_per_token_if_dynamic: + scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None])) + else: + scale_a_value = tl.load(scale_a + expert_id) + scale_b_value = tl.load(scale_b + expert_id) + accumulator *= scale_a_value * scale_b_value + + c_tile = accumulator.to(c_dtype) + + offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) + c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] + c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) + tl.store(c_ptr, c_tile, mask=c_mask) + + +def test_fwd_kernel_stage2(ptfile_path="/home/z00946769/triton-sglang/full_data/test_grouped_gemm_triton_kernel_full.pt"): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + grouped_gemm_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
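+    # grouped_gemm_triton_kernel runs one GEMM per (segment, expert): each K block is
+    # either scaled inside the loop with per-group (group_k / group_n) scales or
+    # accumulated unscaled and multiplied once by the fp8 w8a8 scales, then cast to
+    # c's dtype before the masked store.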
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_logits_processor_fused_softcap_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_logits_processor_fused_softcap_kernel.py new file mode 100644 index 000000000..7dfd04c56 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_logits_processor_fused_softcap_kernel.py @@ -0,0 +1,62 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\layers\logits_processor.py + + +@triton.jit +def fused_softcap_kernel( + full_logits_ptr, + softcapping_value, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load values + x = tl.load(full_logits_ptr + offsets, mask=mask) + + # Perform operations in-place + x = x / softcapping_value + + # Manual tanh implementation using exp + exp2x = tl.exp(2 * x) + x = (exp2x - 1) / (exp2x + 1) + + x = x * softcapping_value + + # Store result + tl.store(full_logits_ptr + offsets, x, mask=mask) + + +def test_logits_processor_fused_softcap_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + fused_softcap_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
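+    # The kernel rewrites full_logits in place as softcapping_value * tanh(x / softcapping_value),
+    # with tanh built from exp(2x). An illustrative torch reference (assuming float logits x):
+    #     ref = softcapping_value * torch.tanh(x / softcapping_value)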
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_memcpy_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_memcpy_triton_kernel.py new file mode 100644 index 000000000..3c3f7e9cc --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_memcpy_triton_kernel.py @@ -0,0 +1,57 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +# source: python\sglang\srt\layers\dp_attention.py +@triton.jit +def memcpy_triton_kernel( + dst_ptr, + src_ptr, + offset_ptr, + sz_ptr, + offset_src: tl.constexpr, + chunk_size, # multiplied for offset and sz + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0).to(tl.int64) + offset = tl.load(offset_ptr).to(tl.int64) * chunk_size + sz = tl.load(sz_ptr).to(tl.int64) * chunk_size + + start_index = pid * BLOCK_SIZE + offs = tl.arange(0, BLOCK_SIZE) + mask = start_index + offs < sz + + if offset_src: + data = tl.load(src_ptr + offset + start_index + offs, mask=mask) + tl.store(dst_ptr + start_index + offs, data, mask=mask) + else: + data = tl.load(src_ptr + start_index + offs, mask=mask) + tl.store(dst_ptr + offset + start_index + offs, data, mask=mask) + +def test_memcpy_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + memcpy_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
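+    # memcpy_triton_kernel copies sz * chunk_size elements and applies the runtime
+    # offset (offset * chunk_size) to the source pointer when offset_src is true,
+    # otherwise to the destination pointer.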
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_merge_state_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_merge_state_kernel.py new file mode 100644 index 000000000..857c8084d --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_merge_state_kernel.py @@ -0,0 +1,86 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: sgl-kernel/tests/test_merge_state.py +@triton.jit +def state_merge(o, m, d, other_o, other_m, other_d): + m_max = tl.maximum(m, other_m) + d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max) + o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max) + return o, m_max, d + + +@triton.jit +def state_normalize(o, m, d): + o = o / d + return o, m, d + + +@triton.jit +def state_get_lse(o, m, d): + return m + tl.log2(d) + + +@triton.jit +def merge_state_kernel( + v_a_ptr, + s_a_ptr, + v_b_ptr, + s_b_ptr, + v_merged_ptr, + s_merged_ptr, + num_heads, + head_dim, + bdx: tl.constexpr, + bdy: tl.constexpr, +): + pos = tl.program_id(axis=0) + for tx in tl.range(bdx): + for head_idx in tl.range(bdy): + s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx) + s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx) + + offsets = (pos * num_heads + head_idx) * head_dim + tx + v_a = tl.load(v_a_ptr + offsets) + v_b = tl.load(v_b_ptr + offsets) + + v_merged, s_max, d = state_merge( + o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1 + ) + v_merged, s_max, d = state_normalize(v_merged, s_max, d) + v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx + tl.store(v_merged_ptr + v_merged_offset, v_merged) + + if s_merged_ptr: + tl.store( + s_merged_ptr + pos * num_heads + head_idx, + tl.log2(d) + s_max, + ) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + merge_state_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage1.py b/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage1.py new file mode 100644 index 000000000..3305c9bc6 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage1.py @@ -0,0 +1,49 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: sgl-kernel/benchmark/bench_moe_align_block_size.py +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = pid * tokens_per_thread + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = 
tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + moe_align_block_size_stage1[data["grid"]](**input_data) + + # compare the results of GPU and NPU + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage2.py b/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage2.py new file mode 100644 index 000000000..2c3319657 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage2.py @@ -0,0 +1,43 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: sgl-kernel/benchmark/bench_moe_align_block_size.py +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + moe_align_block_size_stage2[data["grid"]](**input_data) + + # compare the results of GPU and NPU + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage3.py b/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage3.py new file mode 100644 index 000000000..7cba61d0e --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_moe_align_block_size_stage3.py @@ -0,0 +1,46 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: +# sgl-kernel/tests/test_moe_align.py +# sgl-kernel/benchmark/bench_moe_align_block_size.py +# python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +def test_moe_align_block_size_stage3(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except 
Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + moe_align_block_size_stage3[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_moe_sum_reduce_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_moe_sum_reduce_kernel.py new file mode 100644 index 000000000..ef688b6d6 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_moe_sum_reduce_kernel.py @@ -0,0 +1,83 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\layers\moe\fused_moe_triton\fused_moe.py + + +@triton.jit +def _moe_sum_reduce_kernel( + input_ptr, + input_stride_0, + input_stride_1, + input_stride_2, + output_ptr, + output_stride_0, + output_stride_1, + token_num: int, + topk_num: int, + hidden_dim: int, + routed_scaling_factor: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DIM: tl.constexpr, + NUM_STAGE: tl.constexpr, +): + input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) + input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) + output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64) + + token_block_id = tl.program_id(0) + dim_block_id = tl.program_id(1) + + token_start = token_block_id * BLOCK_M + token_end = min((token_block_id + 1) * BLOCK_M, token_num) + + dim_start = dim_block_id * BLOCK_DIM + dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim) + + offs_dim = dim_start + tl.arange(0, BLOCK_DIM) + + for token_index in range(token_start, token_end): + accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32) + input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim + for i in tl.range(0, topk_num, num_stages=NUM_STAGE): + tmp = tl.load( + input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0 + ) + accumulator += tmp + accumulator = accumulator * routed_scaling_factor + store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim + tl.store( + store_t_ptr, + accumulator.to(input_ptr.dtype.element_ty), + mask=offs_dim < dim_end, + ) + + +def test_moe_sum_reduce_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _moe_sum_reduce_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
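+    # _moe_sum_reduce_kernel sums the topk expert outputs per token in float32, scales
+    # by routed_scaling_factor, and casts back to the input dtype on store. Assuming
+    # input_ptr is laid out as (token_num, topk_num, hidden_dim), an illustrative
+    # torch reference would be:
+    #     ref = (x.float().sum(dim=1) * routed_scaling_factor).to(x.dtype)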
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_post_reorder_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_post_reorder_triton_kernel.py new file mode 100644 index 000000000..d121cfd66 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_post_reorder_triton_kernel.py @@ -0,0 +1,80 @@ +import sys +import pytest +import torch + +import triton +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + dst_start, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx_int32 = tl.program_id(0) + src_idx = src_idx_int32.to(tl.int64) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + computed = False + store_ptr = output_ptr + src_idx * hidden_size + + vec = tl.arange(0, BLOCK_SIZE) + + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < hidden_size + + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + computed = True + dst_idx_int32 = tl.load(src2dst_ptr + idx) + dst_idx = dst_idx_int32.to(tl.int64) + dst_idx = dst_idx - dst_start + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + if computed == False: + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < hidden_size + tl.store( + store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask + ) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + post_reorder_triton_kernel[data['grid']](**input_data) + + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_pre_reorder_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_pre_reorder_triton_kernel.py new file mode 100644 index 000000000..49772cb51 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_pre_reorder_triton_kernel.py @@ -0,0 +1,71 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def pre_reorder_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, + use_per_token_if_dynamic: 
tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + src_ptr = input_ptr + src_idx * hidden_size + + vec = tl.arange(0, BLOCK_SIZE) + + if a1_scales_ptr is not None and use_per_token_if_dynamic: + scale = 1.0 / tl.load(a1_scales_ptr + src_idx) + + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + if a1_scales_ptr is not None: + if not use_per_token_if_dynamic: + scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id) + else: + scale = 1.0 + + dst_idx = tl.load(src2dst_ptr + idx) + dst_ptr = gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + out_data = (in_data * scale).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + +def test_pre_reorder_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + pre_reorder_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_qkv_lora_b_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_qkv_lora_b_kernel.py new file mode 100644 index 000000000..1d01a514c --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_qkv_lora_b_kernel.py @@ -0,0 +1,132 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +# source: python\sglang\srt\lora\triton_ops\qkv_lora_b.py + + +@triton.jit +def _qkv_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + K, # K = R + max_qkv_out_dim, # max(output_q_dim, output_kv_dim) + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + # Offsets of q/k/v slice on output dimension + n_offs, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scalings, +): + # This kernel packs 3 sgemms (q/k/v) into a single kernel. + + # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank + # weights: (num_lora, N_Q + 2 * N_KV, K) + # output: (s, N_Q + 2 * N_KV) + # N_Q >> K, N_KV >> K + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len. 
+ # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) + batch_id = tl.program_id(axis=2) + qkv_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + n_start = tl.load(n_offs + qkv_id) + n_size = tl.load(n_offs + qkv_id + 1) - n_start + rank = tl.load(lora_ranks + w_index) + scaling = tl.load(scalings + w_index) + # Adjust K (rank) according to the specific LoRA adapter + K = tl.minimum(K, rank) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iterate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def test_qkv_lora_b_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _qkv_lora_b_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
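+    # Rough sketch of what the kernel above computes for one sequence segment and
+    # one q/k/v slice (argument names taken from the kernel; weights is laid out as
+    # (num_lora, N, K), so the product runs against its transpose):
+    #   out[rows, n_start:n_end] = scaling * x_qkv_slice[rows, :rank] @ weights[w_index, n_start:n_end, :rank].T
+    # with the previous output added on top when fuse_scaling_add is set.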
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_set_mla_kv_buffer_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_set_mla_kv_buffer_kernel.py new file mode 100644 index 000000000..21ac09b37 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_set_mla_kv_buffer_kernel.py @@ -0,0 +1,73 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: python/sglang/srt/mem_cache/memory_pool.py +@triton.jit +def set_mla_kv_buffer_kernel( + kv_buffer_ptr, + cache_k_nope_ptr, + cache_k_rope_ptr, + loc_ptr, + buffer_stride: tl.constexpr, + nope_stride: tl.constexpr, + rope_stride: tl.constexpr, + nope_dim: tl.constexpr, + rope_dim: tl.constexpr, + BLOCK: tl.constexpr, +): + pid_loc = tl.program_id(0) + pid_blk = tl.program_id(1) + + base = pid_blk * BLOCK + offs = base + tl.arange(0, BLOCK) + total_dim = nope_dim + rope_dim + mask = offs < total_dim + + loc = tl.load(loc_ptr + pid_loc) + dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs + + if base + BLOCK <= nope_dim: + src = tl.load( + cache_k_nope_ptr + pid_loc * nope_stride + offs, + mask=mask, + ) + else: + offs_rope = offs - nope_dim + src = tl.load( + cache_k_rope_ptr + pid_loc * rope_stride + offs_rope, + mask=mask, + ) + + tl.store(dst_ptr, src, mask=mask) + + + +def test_set_mla_kv_buffer_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + set_mla_kv_buffer_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
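+    # What the kernel above is expected to produce, roughly (argument names as in
+    # the kernel; one destination row per entry of loc):
+    #   kv_buffer[loc[i], :nope_dim]                    = cache_k_nope[i, :nope_dim]
+    #   kv_buffer[loc[i], nope_dim:nope_dim + rope_dim] = cache_k_rope[i, :rope_dim]
+    # The kernel branches per BLOCK-sized tile, so it assumes a tile never
+    # straddles the nope/rope boundary (i.e. nope_dim is a multiple of BLOCK).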
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/test/sglang/v0.4.8/test_silu_and_mul_triton_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_silu_and_mul_triton_kernel.py new file mode 100644 index 000000000..b6cc33e78 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_silu_and_mul_triton_kernel.py @@ -0,0 +1,80 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\layers\moe\ep_moe\kernels.py + + +@triton.jit +def silu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, + BLOCK_SIZE: tl.constexpr, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # silu & mul & quantize + gate_output = gate_output * tl.sigmoid(gate_output) + gate_output = gate_output.to(InDtype) + + silu_mul_output = gate_output * up_output * scale + silu_mul_output = silu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) + + +def test_silu_and_mul_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + silu_and_mul_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
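+    # Reference formula implemented by the kernel above, ignoring intermediate
+    # dtype casts (names from the kernel), for every row whose expert falls in
+    # [start_expert_id, end_expert_id]:
+    #   gate = gateup_output[row, :hidden_size // 2]
+    #   up   = gateup_output[row, hidden_size // 2:]
+    #   down_input[row] = (gate * sigmoid(gate)) * up / scales[expert_id - start_expert_id]
+    # with the scale factor fixed to 1 when scales is None.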
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_sparse_fwd_kernel_flash_decode_stage3.py b/third_party/ascend/test/sglang/v0.4.8/test_sparse_fwd_kernel_flash_decode_stage3.py new file mode 100644 index 000000000..361399d6c --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_sparse_fwd_kernel_flash_decode_stage3.py @@ -0,0 +1,77 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +import test_common +sys.path.append("..") + + +# source: python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +@triton.jit +def _sparse_fwd_kernel_flash_decode_stage3( + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + seq_len, # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_o_eb, + stride_mid_o_eh, + stride_obs, + stride_oh, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + + block_n_size = tl.where(seq_len <= 0, 0, seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_seq_n in range(0, block_n_size, 1): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) + return + + +def test_sparse_fwd_kernel_flash_decode_stage3(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _sparse_fwd_kernel_flash_decode_stage3[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
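+    # The kernel above folds the per-block partial results with a running
+    # log-sum-exp; ignoring the streaming order, the value being checked is
+    # equivalent to (names from the kernel, b ranging over sequence blocks):
+    #   m   = max_b Mid_O_LogExpSum[batch, head, b]
+    #   w_b = exp(Mid_O_LogExpSum[batch, head, b] - m)
+    #   O[batch, head, :] = sum_b(w_b * Mid_O[batch, head, b, :]) / sum_b(w_b)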
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_tma_align_input_scale_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_tma_align_input_scale_kernel.py new file mode 100644 index 000000000..cd6cde783 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_tma_align_input_scale_kernel.py @@ -0,0 +1,64 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +#source: python\sglang\srt\layers\moe\ep_moe\kernels.py + + +@triton.jit +def _tma_align_input_scale_kernel( + input_scale_ptr, + output_ptr, + m, + k_div_block_size, + input_scale_stride_m, + input_scale_stride_k, + output_stride_m, + output_stride_k, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + grid_m = tl.num_programs(0) + k_offsets = tl.arange(0, BLOCK_SIZE_K) + + for m_base in range(pid_m, m, grid_m): + input_offset = ( + input_scale_ptr + + m_base * input_scale_stride_m + + k_offsets * input_scale_stride_k + ) + input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size) + + output_offset = ( + output_ptr + k_offsets * output_stride_k + m_base * output_stride_m + ) + tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size) + + +def test_tma_align_input_scale_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _tma_align_input_scale_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
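+    # The kernel above is a strided copy: each program walks rows
+    # m_base = pid_m, pid_m + grid_m, ... and writes, for k < k_div_block_size,
+    #   output[m_base, k] = input_scale[m_base, k]
+    # so only the destination memory layout (strides) changes, not the values.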
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_triton_ops_merge_state_kernel.py b/third_party/ascend/test/sglang/v0.4.8/test_triton_ops_merge_state_kernel.py new file mode 100644 index 000000000..6447f8058 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_triton_ops_merge_state_kernel.py @@ -0,0 +1,90 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +import test_common +sys.path.append("..") + + +# source: python/sglang/srt/layers/attention/triton_ops/merge_state.py +@triton.jit +def merge_state_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged + output_lse, # [NUM_TOKENS, NUM_HEADS] s_merged + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a + prefix_lse, # [NUM_TOKENS, NUM_HEADS] s_a + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b + suffix_lse, # [NUM_TOKENS, NUM_HEADS] s_b + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx) + s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx) + p_lse = float("-inf") if p_lse == float("inf") else p_lse + s_lse = float("-inf") if s_lse == float("inf") else s_lse + + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + out_se = tl.exp(p_lse) + tl.exp(s_lse) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + token_idx * num_heads + head_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load( + prefix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) + s_out = tl.load( + suffix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) + + p_scale = tl.exp(p_lse) / out_se + s_scale = tl.exp(s_lse) / out_se + out = p_out * p_scale + s_out * s_scale + tl.store( + output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask, + ) + + +def test_triton_ops_merge_state_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + merge_state_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
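+    # Per (token, head), the kernel above merges the two partial attention states
+    # with a max-shifted log-sum-exp (names from the kernel):
+    #   m = max(p_lse, s_lse); p = exp(p_lse - m); s = exp(s_lse - m)
+    #   output     = (p * prefix_output + s * suffix_output) / (p + s)
+    #   output_lse = log(p + s) + m    # only written when OUTPUT_LSE is set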
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_write_req_to_token_pool_triton.py b/third_party/ascend/test/sglang/v0.4.8/test_write_req_to_token_pool_triton.py new file mode 100644 index 000000000..43eea4a2d --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_write_req_to_token_pool_triton.py @@ -0,0 +1,70 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: \python\sglang\srt\managers\schedule_batch.py +@triton.jit +def write_req_to_token_pool_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(0) + + req_pool_index = tl.load(req_pool_indices + pid) + pre_len = tl.load(pre_lens + pid) + seq_len = tl.load(seq_lens + pid) + + # NOTE: This can be slow for large bs + cumsum_start = tl.cast(0, tl.int64) + for i in range(pid): + cumsum_start += tl.load(extend_lens + i) + + num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < (seq_len - pre_len) + value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask) + tl.store( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + offset + + pre_len, + value, + mask=mask, + ) + + +def test_write_req_to_token_pool_triton(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + write_req_to_token_pool_triton[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
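+    # Reference behaviour of the kernel above, roughly (names from the kernel):
+    #   start = sum(extend_lens[:pid]); n = seq_lens[pid] - pre_lens[pid]
+    #   req_to_token[req_pool_indices[pid], pre_len : pre_len + n] = out_cache_loc[start : start + n]
+    # i.e. each program copies its request's freshly allocated cache slots into
+    # the request-to-token table right after the already-cached prefix.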
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/test/sglang/v0.4.8/test_write_req_to_token_pool_triton_optimize.py b/third_party/ascend/test/sglang/v0.4.8/test_write_req_to_token_pool_triton_optimize.py new file mode 100644 index 000000000..57c3d4d77 --- /dev/null +++ b/third_party/ascend/test/sglang/v0.4.8/test_write_req_to_token_pool_triton_optimize.py @@ -0,0 +1,74 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + +# source: benchmark\kernels\scheduler_batch\benchmark_write_req_to_token_pool_triton.py +@triton.jit +def write_req_to_token_pool_triton_optimize( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_token = tl.program_id(1) + + req_pool_index = tl.load(req_pool_indices + pid_batch) + pre_len = tl.load(pre_lens + pid_batch) + seq_len = tl.load(seq_lens + pid_batch) + extend_len = seq_len - pre_len + + cumsum_start = 0 + for i in range(pid_batch): + cumsum_start += tl.load(extend_lens + i) + + token_start = pid_token * BLOCK_SIZE + + offset = tl.arange(0, BLOCK_SIZE) + actual_offset = token_start + offset + mask = actual_offset < extend_len + + src_ptr = out_cache_loc + cumsum_start + actual_offset + src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE) + value = tl.load(src_ptr, mask=mask) + dst_ptr = ( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + actual_offset + + pre_len + ) + dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE) + + tl.store(dst_ptr, value, mask=mask) + +def test_write_req_to_token_pool_triton_optimize(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # key : value + # [gpu_output] (dict): + # key : value + # [grid] : + # (1,) + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + write_req_to_token_pool_triton_optimize[data["grid"]](**input_data) + + # compare the results of GPU and NPU. 
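+    # Same copy as write_req_to_token_pool_triton, but parallelised over a 2D
+    # grid: axis 0 indexes the request and axis 1 tiles the extended length in
+    # BLOCK_SIZE chunks. The exact grid shape comes from data["grid"] in the
+    # fixture; the benchmark presumably launches (batch, cdiv(max_extend_len, BLOCK_SIZE)).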
+ try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/CMakeLists.txt b/third_party/ascend/triton-adapter/CMakeLists.txt new file mode 100644 index 000000000..ebeb0b8e3 --- /dev/null +++ b/third_party/ascend/triton-adapter/CMakeLists.txt @@ -0,0 +1,36 @@ +# Security compilation options settings +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIE") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIE") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-strong") +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,now -pie -s") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,now -s") +set(CMAKE_SKIP_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +unset(CMAKE_INSTALL_RPATH) + +option(TRITON_ADAPTER_BUILD_CPU_BACKEND "Build triton-adapter CPU backend" ON) + +set(TRITON_ADAPTER_SOURCE_DIR ".") +set(TRITON_ADAPTER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") + +include_directories(./include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files + +# Triton Adaptor is dependent on AscendNPU IR +set(ASCENDNPU_IR_SRC_DIR ${PROJECT_SOURCE_DIR}/third_party/ascend/third_party/ascendnpu-ir) +set(ASCENDNPU_IR_BINARY_DIR ${PROJECT_BINARY_DIR}/third_party/ascend/third_party/ascendnpu-ir) +set(BISHENGIR_BUILD_STANDALONE_IR_ONLY ON) + +add_subdirectory(${ASCENDNPU_IR_SRC_DIR} ${ASCENDNPU_IR_BINARY_DIR}) +include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include) +include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) # Tablegen'd files + +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(tools) + +if (TRITON_ADAPTER_BUILD_CPU_BACKEND) + add_triton_plugin(TritonAdapter triton_adapter.cc LINK_LIBS TritonToLinalg) +endif() diff --git a/third_party/ascend/triton-adapter/include/CMakeLists.txt b/third_party/ascend/triton-adapter/include/CMakeLists.txt new file mode 100644 index 000000000..e6afa44d4 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(TritonToAnnotation) +add_subdirectory(TritonToHIVM) +add_subdirectory(TritonToLinalg) +add_subdirectory(DiscreteMaskAccessConversion) +add_subdirectory(TritonToUnstructure) diff --git a/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/CMakeLists.txt b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/CMakeLists.txt new file mode 100644 index 000000000..567a119e5 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name DiscreteMaskAccessConversion) +add_public_tablegen_target(DiscreteMaskAccessConversionPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.h b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.h new file mode 100644 index 000000000..f22ab03a8 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.h @@ -0,0 +1,34 @@ +#ifndef TRITON_ADAPTER_DISCRETEMASKACCESSCONVERSION_H +#define 
TRITON_ADAPTER_DISCRETEMASKACCESSCONVERSION_H + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_CLASSES +#include "../../include/DiscreteMaskAccessConversion/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> createDiscreteMaskAccessConversionPass(); + +} // namespace triton +} // namespace mlir + +namespace { + +using namespace mlir; +using namespace triton; + +class DiscreteMaskAccessConversionPass + : public DiscreteMaskAccessConversionBase { +public: + + void runOnOperation() override; +}; + +} // namespace + +#endif // DISCRETE_MASK_ACCESS_CONVERSION_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.h b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.h new file mode 100644 index 000000000..bcda15e6b --- /dev/null +++ b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. + */ + +#ifndef TRITON_ADAPTER_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H +#define TRITON_ADAPTER_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H + +#include "DiscreteMaskAccessConversionPass.h" + +namespace mlir { +namespace triton { + +/// Creates a pass to convert Triton dialect to HIVM dialect. +std::unique_ptr> createDiscreteMaskAccessConversionPass(); + +#define GEN_PASS_REGISTRATION +#include "ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_DISCRETE_MASK_ACCESS_CONVERSION_PASSES_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.td b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.td new file mode 100644 index 000000000..8beded0e4 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.td @@ -0,0 +1,16 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. + */ + +#ifndef DISCRETE_MASK_ACCESS_CONVERSION_PASSES +#define DISCRETE_MASK_ACCESS_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def DiscreteMaskAccessConversion : Pass<"discrete-mask-access-conversion", "mlir::ModuleOp"> { + let summary = "Recognize and convert discrete mask memory access"; + let constructor = "triton::createDiscreteMaskAccessConversionPass()"; +} + +#endif // DISCRETE_MASK_ACCESS_CONVERSION_PASSES diff --git a/third_party/ascend/triton-adapter/include/TritonToAnnotation/CMakeLists.txt b/third_party/ascend/triton-adapter/include/TritonToAnnotation/CMakeLists.txt new file mode 100644 index 000000000..69d72d21e --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToAnnotation/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToAnnotation) +add_public_tablegen_target(TritonToAnnotationConversionPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToAnnotation/Passes.h b/third_party/ascend/triton-adapter/include/TritonToAnnotation/Passes.h new file mode 100644 index 000000000..21ec55337 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToAnnotation/Passes.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. 
+ */ + +#ifndef TRITON_ADAPTER_TRITON_TO_ANNOTATION_CONVERSION_PASSES_H +#define TRITON_ADAPTER_TRITON_TO_ANNOTATION_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +// Forward declarations. +class ModuleOp; + +namespace triton { + +/// Creates a pass to convert Triton dialect to Annotation dialect. +std::unique_ptr> createTritonToAnnotationPass(); + +#define GEN_PASS_REGISTRATION +#include "ascend/triton-adapter/include/TritonToAnnotation/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_TRITON_TO_ANNOTATION_CONVERSION_PASSES_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToAnnotation/Passes.td b/third_party/ascend/triton-adapter/include/TritonToAnnotation/Passes.td new file mode 100644 index 000000000..58598a5c9 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToAnnotation/Passes.td @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. + */ + +#ifndef TRITON_TO_ANNOTATION_CONVERSION_PASSES +#define TRITON_TO_ANNOTATION_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToAnnotation : Pass<"triton-to-annotation", "mlir::ModuleOp"> { + let summary = "Convert Triton to Annotation dialect"; + let constructor = "triton::createTritonToAnnotationPass()"; + let dependentDialects = ["annotation::AnnotationDialect"]; +} + +#endif // TRITON_TO_ANNOTATION_CONVERSION_PASSES diff --git a/third_party/ascend/triton-adapter/include/TritonToHIVM/CMakeLists.txt b/third_party/ascend/triton-adapter/include/TritonToHIVM/CMakeLists.txt new file mode 100644 index 000000000..4db98b26b --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToHIVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToHIVM) +add_public_tablegen_target(TritonToHIVMConversionPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToHIVM/Passes.h b/third_party/ascend/triton-adapter/include/TritonToHIVM/Passes.h new file mode 100644 index 000000000..0903a44f2 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToHIVM/Passes.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. + */ + +#ifndef TRITON_ADAPTER_TRITON_TO_HIVM_CONVERSION_PASSES_H +#define TRITON_ADAPTER_TRITON_TO_HIVM_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +// Forward declarations. +class ModuleOp; + +namespace triton { + +/// Creates a pass to convert Triton dialect to HIVM dialect. +std::unique_ptr> createTritonToHIVMPass(); + +#define GEN_PASS_REGISTRATION +#include "ascend/triton-adapter/include/TritonToHIVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_TRITON_TO_HIVM_CONVERSION_PASSES_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToHIVM/Passes.td b/third_party/ascend/triton-adapter/include/TritonToHIVM/Passes.td new file mode 100644 index 000000000..7aa37c513 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToHIVM/Passes.td @@ -0,0 +1,17 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. 
+ */ + +#ifndef TRITON_TO_HIVM_CONVERSION_PASSES +#define TRITON_TO_HIVM_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToHIVM : Pass<"triton-to-hivm", "mlir::ModuleOp"> { + let summary = "Convert Triton to HIVM dialect"; + let constructor = "triton::createTritonToHIVMPass()"; + let dependentDialects = ["hivm::HIVMDialect"]; +} + +#endif // TRITON_TO_HIVM_CONVERSION_PASSES diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h new file mode 100644 index 000000000..d90ad0dda --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/ArgMinMaxConverter.h @@ -0,0 +1,317 @@ +#ifndef TRITON_ADAPTER_ARGMINMAXCONVERTER_H +#define TRITON_ADAPTER_ARGMINMAXCONVERTER_H + +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "ConversionPatterns.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +#define DEBUG_TYPE "triton-to-linalg" + +#include "llvm/Support/Debug.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +template +class ArgMinMaxBaseConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchTieBreakResult(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &tileBreakValue) const { + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto eqCmpOp = dyn_cast(*it); + if (eqCmpOp) { + if (eqCmpOp.getPredicate() != arith::CmpFPredicate::OEQ || + currValue != eqCmpOp.getLhs() || reduceValue != eqCmpOp.getRhs()) { + return failure(); + } + } + + auto eqCmpIOp = dyn_cast(*it++); + if (eqCmpIOp) { + if (eqCmpIOp.getPredicate() != arith::CmpIPredicate::eq || + currValue != eqCmpIOp.getLhs() || reduceValue != eqCmpIOp.getRhs()) { + return failure(); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto sltCmpOp = dyn_cast(*it++); + if (!sltCmpOp || sltCmpOp.getPredicate() != arith::CmpIPredicate::slt || + currIndex != sltCmpOp.getLhs() || reduceIndex != sltCmpOp.getRhs()) { + return failure(); + } + + // matching: %13 = arith.andi %11, %12 : i1 + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto andOp = dyn_cast(*it++); + + Value cmpOp; + if (eqCmpOp) + cmpOp = eqCmpOp; + else + cmpOp = eqCmpIOp; + + if (!andOp || andOp.getLhs() != cmpOp || andOp.getRhs() != sltCmpOp) { + return failure(); + } + + tileBreakValue = andOp; + return success(); + } + + LogicalResult matchShouldUpdateValue(Value currValue, Value currIndex, + Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, + Value &shouldUpdate) const { + Value tieResult; + if (failed(matchTieBreakResult(currValue, currIndex, reduceValue, + reduceIndex, it, tieResult))) { + LLVM_DEBUG(llvm::dbgs() << "Tie break result match failed\n"); + return failure(); + } + + Value comparisonResult; + if (failed(T::matchComparisonResult(currValue, currIndex, reduceValue, + reduceIndex, it, comparisonResult))) { + LLVM_DEBUG(llvm::dbgs() << "Comparison result match failed\n"); + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + auto orOp = dyn_cast(*it++); + if (!orOp || orOp.getLhs() != comparisonResult || + orOp.getRhs() != tieResult) { + return failure(); + } + + shouldUpdate = orOp; + return success(); + 
} + + Value getInitTensor(ConversionPatternRewriter &rewriter, + ArrayRef shape, Value fillValue, + Location loc) const { + Value initTensor = + rewriter.create(loc, shape, fillValue.getType()); + return rewriter + .create(loc, ValueRange{fillValue}, + ValueRange{initTensor}) + .result(); + } + +public: + ArgMinMaxBaseConverter(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult match(triton::ReduceOp op) const override final { + if (op.getBody()->getNumArguments() != 4) { + return failure(); + } + + auto block = op.getBody(); + auto ops = block->without_terminator(); + + Value currValue = block->getArgument(0); + Value currIndex = block->getArgument(1); + Value reduceValue = block->getArgument(2); + Value reduceIndex = block->getArgument(3); + + auto opsIt = ops.begin(); + Value shouldUpdate; + if (failed(matchShouldUpdateValue(currValue, currIndex, reduceValue, + reduceIndex, opsIt, shouldUpdate))) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto valueSelectOp = dyn_cast(*opsIt++); + if (!valueSelectOp || valueSelectOp.getCondition() != shouldUpdate || + currValue != valueSelectOp.getTrueValue() || + reduceValue != valueSelectOp.getFalseValue()) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto indexSelectOp = dyn_cast(*opsIt++); + if (indexSelectOp) { + if (indexSelectOp.getCondition() != shouldUpdate || + currIndex != indexSelectOp.getTrueValue() || + reduceIndex != indexSelectOp.getFalseValue()) { + return failure(); + } + } else { + return failure(); + } + if (!indexSelectOp || indexSelectOp.getCondition() != shouldUpdate || + currIndex != indexSelectOp.getTrueValue() || + reduceIndex != indexSelectOp.getFalseValue()) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *opsIt << "\n"); + auto termOp = dyn_cast(*opsIt++); + if (!(termOp && termOp == block->getTerminator() && + termOp.getOperands() == + ArrayRef{valueSelectOp, indexSelectOp})) { + return failure(); + } + return success(); + } + + void rewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override final { + auto loc = op.getLoc(); + auto elemTypes = op.getElementTypes(); + + auto valueType = elemTypes[0]; + // tl.argmin reorder + auto block = op.getBody(); + if (isa(valueType)) { + arith::CmpFOp cmpFOp; + block->walk([&](arith::CmpFOp cmpOp) { + auto pred = cmpOp.getPredicate(); + if (pred == arith::CmpFPredicate::OEQ || + pred == arith::CmpFPredicate::ONE || + pred == arith::CmpFPredicate::UEQ || + pred == arith::CmpFPredicate::UNE) { + return WalkResult::advance(); + } else if (pred == arith::CmpFPredicate::OGT || + pred == arith::CmpFPredicate::OLT || + pred == arith::CmpFPredicate::UGT || + pred == arith::CmpFPredicate::ULT) { + cmpFOp = cmpOp; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + cmpFOp->moveBefore(block, block->getOperations().begin()); + } else if (isa(valueType)) { + arith::CmpIOp cmpIOp; + block->walk([&](arith::CmpIOp cmpOp) { + auto pred = cmpOp.getPredicate(); + if (pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) { + return WalkResult::advance(); + } else if (pred == arith::CmpIPredicate::sgt || + pred == arith::CmpIPredicate::slt || + pred == arith::CmpIPredicate::ugt || + pred == arith::CmpIPredicate::ult) { + if (cmpOp.getLhs() == block->getArgument(0) && + cmpOp.getRhs() == block->getArgument(2)) { + cmpIOp = cmpOp; + return WalkResult::interrupt(); + } + } + return 
WalkResult::advance(); + }); + cmpIOp->moveBefore(block, block->getOperations().begin()); + } + + TypedAttr valueAttr; + if (isa(valueType)) { + valueAttr = rewriter.getFloatAttr(valueType, T::getBaseReductionValue()); + } else if (isa(valueType)) { + // TODO: support other type of int + valueAttr = + rewriter.getIntegerAttr(valueType, T::getBaseReductionIntValue()); + } + + auto valuesAccBaseVal = + rewriter.create(loc, valueType, valueAttr); + + auto indexType = elemTypes[1]; + auto indicesAccBaseVal = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, -1)); + + auto valueResultType = dyn_cast(op.getType(0)); + const auto isScalarReduce = valueResultType == nullptr; + SmallVector reductionResultShape{ + isScalarReduce ? SmallVector{} + : SmallVector(valueResultType.getShape())}; + + SmallVector outputs{ + getInitTensor(rewriter, reductionResultShape, valuesAccBaseVal, loc), + getInitTensor(rewriter, reductionResultShape, indicesAccBaseVal, loc)}; + + auto linalgOp = rewriter.create( + loc, adaptor.getOperands(), outputs, + SmallVector{adaptor.getAxis()}, + [&](OpBuilder &b, Location loc, ValueRange inputs) { + assert(inputs.size() == 4); + + auto tritonReduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(tritonReduceBlock->getArguments(), inputs); + + for (auto &op : tritonReduceBlock->without_terminator()) { + b.clone(op, mapping); + } + + auto tritonYield = tritonReduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(tritonYield->getOperands(), [&](Value val) { + return mapping.lookup(val); + }); + b.create(loc, results); + }); + + // before we rewrite the argmax reduce op, we know it has return value + // so addReduceWithIndexAttrIfNeeded won't fail + // but ignoring it will lead to compiling failure + auto logicalResult = addReduceWithIndexAttrIfNeeded(rewriter, linalgOp); + + if (isScalarReduce) { + SmallVector reduceResults{ + rewriter.create( + loc, valueType, linalgOp.getResults()[0], ValueRange{}), + rewriter.create( + loc, indexType, linalgOp.getResults()[1], ValueRange{})}; + rewriter.replaceOp(op, reduceResults); + } else { + rewriter.replaceOp(op, linalgOp); + } + } +}; + +class ArgMinConverter : public ArgMinMaxBaseConverter { +public: + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult); + + static float getBaseReductionValue(); + + static int8_t getBaseReductionIntValue(); + + ArgMinConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +class ArgMaxConverter : public ArgMinMaxBaseConverter { +public: + static LogicalResult matchComparisonResult(Value currValue, Value currIndex, + Value reduceValue, + Value reduceIndex, + mlir::Block::iterator &it, + Value &comparisonResult); + + static float getBaseReductionValue(); + + static int8_t getBaseReductionIntValue(); + + ArgMaxConverter(MLIRContext *context) : ArgMinMaxBaseConverter(context) {} +}; + +} // namespace TTOpConverters + +#endif \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h new file mode 100644 index 000000000..764544dca --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/BlockPtrAnalysis.h @@ -0,0 +1,254 @@ +#ifndef TRITON_ANALYSIS_BLOCKPTRANALYSIS_H +#define TRITON_ANALYSIS_BLOCKPTRANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include 
"mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#include +namespace mlir { + +class ConversionPatternRewriter; + +namespace triton { + +enum class MemAccVal { Undefined = 0, StrucMemAcc = 1, UnstrucMemAcc = 2 }; + +struct MemAccType { + + MemAccVal value; + + explicit constexpr MemAccType(MemAccVal v = MemAccVal::Undefined) + : value(v) {} + + constexpr operator MemAccVal() const { return value; } + explicit operator bool() = delete; + + constexpr bool isUndefined() const { return value == MemAccVal::Undefined; } + constexpr bool isStructured() const { + return value == MemAccVal::StrucMemAcc; + } + constexpr bool isUnstructured() const { + return value == MemAccVal::UnstrucMemAcc; + } + + void merge(MemAccType &other) { + this->value = (this->value > other.value) ? this->value : other.value; + } + + std::string_view toString() const { + static constexpr std::string_view names[] = {"Undefined", "StrucMemAcc", + "UnstrucMemAcc"}; + return names[static_cast(value)]; + } +}; + +class BlockData { +public: + SmallVector &getOffsetsRef(); + SmallVector &getSizesRef(); + SmallVector &getStridesRef(); + Value &getSourceRef(); + OpFoldResult &getScalarRef(); + Type &getResElemTyRef(); + MemAccType &getMemAccTypeRef(); + + SmallVector getOffsets() const; + SmallVector getSizes() const; + SmallVector getStrides() const; + Type getResElemTy() const; + OpFoldResult getOffset(int) const; + OpFoldResult getSize(int) const; + OpFoldResult getStride(int) const; + OpFoldResult getScalar() const; + Value getSource() const; + MemAccType getMemAccType() const; + + bool isScalar() const; + bool isEmpty() const; + bool hasSource() const; + bool hasResElemTy() const; + void removeSource(); + + int64_t getRank() const; + MemRefType getResultMemrefType(int64_t offset, + ArrayRef resultShape) const; + + void addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter); + void mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter); + void divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter); + + memref::ReinterpretCastOp createCastOp(ArrayRef resultShape, + const Location &loc, + OpBuilder &builder) const; + + void setResElemTy(const Type &); + void setSource(const Value &); + void setScalar(const OpFoldResult &); + void setOffsets(const SmallVector &); + void setStrides(const SmallVector &); + void setSizes(const SmallVector &); + void setMemAccTy(const MemAccType &); + void setMemAccVal(const MemAccVal); + + void dump() const; + +private: + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + Value source; + // `Scalar` is a shortcut used when the entire blockdata describes a single + // scalar value + OpFoldResult scalar; + Type resElemTy; + MemAccType memAccTy; + + // Accumulate offsets of each dimension in BlockData to get a total offset + // from source ptr, which is used in memref::ReinterpretCastOp + OpFoldResult inferBlockOffset(const Location &loc, OpBuilder &builder) const; +}; + +class BlockDataParser { +public: + static Value getScalarMemRef(Value ptr, Value memref, const Location &loc, + ConversionPatternRewriter &rewriter); + + static void parse(Value operand, BlockData &data, const Location 
&loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseAdd(arith::AddIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseMul(arith::MulIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseDiv(arith::DivSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseRem(arith::RemSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseUnrealizedCast(UnrealizedConversionCastOp op, BlockData &data, + const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseMakeRange(triton::MakeRangeOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseExpandDims(triton::ExpandDimsOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseBitcast(triton::BitcastOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseExtSI(arith::ExtSIOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseBroadcast(triton::BroadcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseSplat(triton::SplatOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseConstSplat(arith::ConstantOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + template + static std::enable_if_t || + std::is_same_v> + parseTensorPtr(T op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseAddPtr(triton::AddPtrOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseExtractSlice(tensor::ExtractSliceOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void + parseReinterpretCast(memref::ReinterpretCastOp op, BlockData &data, + const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void parseReduce(triton::ReduceOp op, BlockData &data, + const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + + static void rewriteAddPtr(triton::AddPtrOp op, + triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known); + + static void rewriteMakeTensorPtrOp(triton::MakeTensorPtrOp op, + Value base, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known); + + static void rewriteAdvanceOp(triton::AdvanceOp op, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known); + + static void + rewriteYieldOp(scf::YieldOp op, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseSet &blockArgIdxSet, ArrayRef iterArgIdxMap, + 
const llvm::SmallDenseMap &known); + + /// @param known is mainly designed for `rewriteLoop`, and is just non-const in + /// `rewriteLoop`, `rewriteAddPtr` and `rewriteAdvance` + static void rewriteLoopOp(LoopLikeOpInterface op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known); + + static void rewriteAddPtrToUnstrucMemAcc(triton::AddPtrOp op, + triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, + BlockData &data); +}; + +template +void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known); + +} // namespace triton + +} // namespace mlir + +#endif \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt b/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..57914a554 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg) +add_public_tablegen_target(TritonToLinalgConversionPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h new file mode 100644 index 000000000..b1c4c7601 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/ConversionPatterns.h @@ -0,0 +1,111 @@ +#ifndef CONVERSIONPATTERNS_H +#define CONVERSIONPATTERNS_H + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include +#include +#include + +using namespace mlir; +using namespace triton; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +static Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + 
auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + rewriter, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +static SmallVector getNParallelLoopsAttrs(unsigned n) { + return SmallVector(n, utils::IteratorType::parallel); +} + +// for IntLike and FloatLike types +static std::optional getBitWidth(Type a) { + if (auto type = dyn_cast(a)) { + auto elementType = type.getElementType(); + if (elementType.isIntOrFloat()) { + return type.getElementType().getIntOrFloatBitWidth(); + } + return std::nullopt; + } + + if (a.isIntOrFloat()) { + return a.getIntOrFloatBitWidth(); + } + return std::nullopt; +} +#endif // CONVERSIONPATTERNS_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/DescriptorConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/DescriptorConverter.h new file mode 100644 index 000000000..cfb998798 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/DescriptorConverter.h @@ -0,0 +1,48 @@ +#ifndef TRITON_ADAPTER_DESCRIPTORCONVERTER_H +#define TRITON_ADAPTER_DESCRIPTORCONVERTER_H + +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +namespace DescriptorConverter { +using namespace mlir; +using namespace triton; + +struct Descriptor { + Value base; + SmallVector shape; + SmallVector strides; +}; + +bool hasATensorDescriptorType(mlir::TypeRange types); + +class DescriptorLoadConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(triton::DescriptorLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class DescriptorStoreConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(triton::DescriptorStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // end of namespace DescriptorConverter + +#endif // TRITON_ADAPTER_DESCRIPTORCONVERTER_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h new file mode 100644 index 000000000..3d1bd931a --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/FunctionConverter.h @@ -0,0 +1,38 @@ +#ifndef TRITON_ADAPTER_FUNCTIONCONVERTER_H +#define TRITON_ADAPTER_FUNCTIONCONVERTER_H + +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace FunctionConverter 
{ +using namespace mlir; +using namespace triton; + +class GetProgramIDConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class GetNumProgramsConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static uint32_t constexpr LAUNCH_GRID_RANK = + getMaxEnumValForProgramIDDim() + 1; + +public: + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace FunctionConverter +#endif \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h new file mode 100644 index 000000000..57ae225c0 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h @@ -0,0 +1,222 @@ +#ifndef TRITON_ADAPTER_LOADSTORECONVERTER_H +#define TRITON_ADAPTER_LOADSTORECONVERTER_H + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace LoadStoreConverter { + +using namespace mlir; +using namespace triton; + +class AddPtrConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class LoadConverter : public OpConversionPattern { +private: + LogicalResult toTensorAndReplace(triton::LoadOp &op, + RankedTensorType &tensorType, + memref::AllocOp &allocOp, + bool mayImplicitTransposeWithLastAxis, + const Location &loc, + ConversionPatternRewriter &rewriter) const; + + LogicalResult checkModifiedByAddPtrConverter(triton::LoadOp &op) const; + + LogicalResult + continueModifyFromAddPtrConverter(triton::LoadOp &op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + void + fillTensorWithOtherForMaskScenario(Value other, memref::AllocOp localMem, + ArrayRef maskDim, + ConversionPatternRewriter &rewriter) const; +public: + explicit LoadConverter(MLIRContext *context); + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +// template class's implementation must be in the header file +template +class LoadStoreCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Value ptrVal = op.getPtr(); + Type ptrTy = ptrVal.getType(); + auto ptrDefOp = ptrVal.getDefiningOp(); + if (isa(ptrVal)) + return failure(); + + if (!isTensorPointerType(ptrTy) && + !isa_and_nonnull(ptrDefOp)) { + if (isa(ptrDefOp)) { + auto castOp = cast(ptrDefOp); + auto castSrc = castOp.getSrc(); + if 
(!isa(castSrc)) { + auto castSrcDefOp = castSrc.getDefiningOp(); + if (isa(castSrcDefOp)) { + return rewriter.notifyMatchFailure( + op, "BitcastCanonicalizer handles addptr->bitcast->load!"); + } + } + } + + Type zeroTy = getI32SameShape(ptrTy); + Value zeroVal = + createScalarOrSplatConstant(rewriter, op.getLoc(), zeroTy, 0); + Value addptrVal = rewriter.create(op.getLoc(), ptrTy, + ptrVal, zeroVal); + rewriter.modifyOpInPlace( + op, [&]() { op->replaceUsesOfWith(ptrVal, addptrVal); }); + return success(); + } + return failure(); + } +}; + +class ScalarStoreCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override; +}; + +class StoreConverter : public OpConversionPattern { +public: + explicit StoreConverter(MLIRContext *context); + + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ScalarAtomicRMWCanonicalizer + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override; +}; + +class ScalarAtomicCASCanonicalizer + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::AtomicCASOp op, + PatternRewriter &rewriter) const override; +}; + +class AtomicCASConverter : public OpConversionPattern { +public: + explicit AtomicCASConverter(MLIRContext *context) : + OpConversionPattern(context) {} + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class AtomicRMWConverter : public OpConversionPattern { +private: + Value createAtomicBinaryOps(OpBuilder &builder, Location loc, + triton::AtomicRMWOp op, Type elementType, + Value lhs, Value rhs) const { + auto rmwOp = op.getAtomicRmwOp(); + + // it has been confirmed in AtomicRMWConverter::matchAndRewrite + // that the ptr of op is of MemRefType + Value binaryOp; + if (rmwOp == triton::RMWOp::FADD) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::ADD) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::XOR) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::OR) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::AND) { + binaryOp = builder.create(loc, lhs, rhs); + } else if (rmwOp == triton::RMWOp::MAX) { + // Max/Min only support f32/i32 for now + // Other type is not supported because of semantic.py + if (isa(elementType)) { + binaryOp = builder.create(loc, lhs, rhs); + } else { + binaryOp = builder.create(loc, lhs, rhs); + } + } else if (rmwOp == triton::RMWOp::MIN) { + if (isa(elementType)) { + binaryOp = builder.create(loc, lhs, rhs); + } else { + binaryOp = builder.create(loc, lhs, rhs); + } + } else if (rmwOp == triton::RMWOp::XCHG) { + binaryOp = rhs; + } else { + op.emitOpError("unsupported atomic RMW operation: "); + llvm_unreachable( + "Not implemented. 
Support fadd, add, max, min for now !"); + } + return binaryOp; + } + + // used when handling scalar + // to verify whether we need to handle this scalar + bool isConstantMaskTrue(Value mask) const { + if (auto denseAttr = + mask.getDefiningOp()->getAttrOfType("value")) { + auto eleType = denseAttr.getType().getElementType(); + if (isa(eleType) && + cast(eleType).getWidth() == 1) { + auto values = denseAttr.getValues(); + return values[0]; + } + } + return false; + } + + DenseSet softwareAtomicKinds = { + triton::RMWOp::AND, triton::RMWOp::OR, triton::RMWOp::XOR}; + +public: + explicit AtomicRMWConverter(MLIRContext *context); + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class AtomicMaxMinCanonicalizer : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const override; +}; + +} // namespace LoadStoreConverter +#endif \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h new file mode 100644 index 000000000..56e9362b9 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/MaskAnalysis.h @@ -0,0 +1,126 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_MASKANALYSIS_H +#define TRITON_ANALYSIS_MASKANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +namespace mlir { + +// this class helps build Operations +class OpBuilder; + +namespace triton { +// use to decode the pattern in a mask used for load and store + +class MaskState { +public: + OpFoldResult start; + OpFoldResult end; + SmallVector dims; + SmallVector offsets; + OpFoldResult scalar; + + int64_t getRank() const { + assert(dims.size() == offsets.size() && "dims and offsets rank mismatch!"); + return dims.size(); + } + + bool isEmpty() const { return getRank() == 0 && !scalar && !start && !end; } + + bool isMask() const { + return !start && !end && !scalar && dims.size() != 0 && offsets.size() != 0; + } + + // parse value recursively + LogicalResult parse(Value operand, const Location &loc, OpBuilder &builder); + + tensor::ExtractSliceOp getExtractSlice(Value source, const Location &loc, + OpBuilder &builder) const; + + tensor::InsertSliceOp getInsertSlice(Value source, Value dest, + const Location &loc, + OpBuilder &builder) const; + + memref::SubViewOp getSubview(Value source, const Location &loc, + OpBuilder &builder) const; + + void eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter); + +private: + LogicalResult addStateScalar(const MaskState &state, + const OpFoldResult scalar, const Location &loc, + OpBuilder &builder); + + LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + LogicalResult divStateScalar(const MaskState &state, + const OpFoldResult scalar, const Location &loc, + OpBuilder &builder); + + LogicalResult divStates(const MaskState &lhsState, const 
MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + // Helper function to handle operator `and` both mask state + LogicalResult minStates(const MaskState &lhsState, const MaskState &rhsState, + const Location &loc, OpBuilder &builder); + + // Helper functions to parse values to populate MaskState + + LogicalResult parseConstant(arith::ConstantOp constOp, const Location &loc, + OpBuilder &builder); + + // Operand is an integer scalar + LogicalResult parseIntScalar(Value scalar, const Location &loc, + OpBuilder &builder); + + // TODO + LogicalResult parseAdd(arith::AddIOp addOp, const Location &loc, + OpBuilder &builder); + + // operand is the result of divsi + LogicalResult parseDiv(arith::DivSIOp divOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of andi + LogicalResult parseAnd(arith::AndIOp andOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of cmpi, necessary method to fuse scalar, start and + // end into dims and offset + LogicalResult parseCmp(arith::CmpIOp cmpOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of make_range + LogicalResult parseMakeRange(triton::MakeRangeOp rangeOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of broadcast + LogicalResult parseBroadcast(triton::BroadcastOp broadcastOp, + const Location &loc, OpBuilder &builder); + + // Operand is the result of splat + LogicalResult parseSplat(triton::SplatOp splatOp, const Location &loc, + OpBuilder &builder); + + // Operand is the result of expand_dims + LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location &loc, OpBuilder &builder); +}; + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h new file mode 100644 index 000000000..6364d12e5 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H +#define TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H + +#include "TritonToLinalgPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "ascend/triton-adapter/include/TritonToLinalg/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_TRITON_TO_LINALG_CONVERSION_PASSES_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td new file mode 100644 index 000000000..10fedf918 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/Passes.td @@ -0,0 +1,22 @@ +#ifndef TRITON_TO_LINALG_CONVERSION_PASSES +#define TRITON_TO_LINALG_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> { + let summary = "Convert Triton to Linalg dialect"; + let constructor = "triton::createTritonToLinalgPass()"; + let options = [ + Option<"globalKernel", "global-kernel", + "bool", /*default*/"true", + "generate a global kernel">, + Option<"namedOps", "named-ops", + "bool", /*default*/"false", + "use linalg named ops instead of linalg.generic">, + Option<"enableNd2nzOnVector", "enable-nd2nz-on-vector", + "bool", /*default*/"false", + "enable nd2nz on vector"> + ]; +} + +#endif // TRITON_TO_LINALG_CONVERSION_PASSES diff --git 
a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h new file mode 100644 index 000000000..f22be76aa --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h @@ -0,0 +1,541 @@ +#ifndef TRITON_ADAPTER_TRITONOPCONVERTER_H +#define TRITON_ADAPTER_TRITONOPCONVERTER_H + +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-to-linalg" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +/* +Convert `tt.precise_div` operation to `arith.divf` operation. +tensor_x / tensor_y + +```ttir + %11 = tt.precise_divf %7, %10 : tensor<100xf32> +``` + +converts to: + +```mlir + %11 = arith.divf %7, %10 : tensor<100xf32> +``` +*/ +struct PreciseDivConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/* + * Move tt.bitcast to a previous location if tt.bitcast is not directly applied + * on function arguments + */ +class BitcastCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::BitcastOp bitcastOp, + PatternRewriter &rewriter) const override; +}; + +template +class ScalarMathCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MathOp op, + PatternRewriter &rewriter) const override { + if (op->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer expects single scalar output."); + } + if (!op->getResult(0).getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer handles scalar load scene."); + } + if (auto linalgOp = op->template getParentOfType()) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer handles op not within tt.reduce."); + } + if (auto linalgOp = op->template getParentOfType()) { + return rewriter.notifyMatchFailure( + op, "ScalarMathCanonicalizer handles op not within tt.scan."); + } + auto loc = op.getLoc(); + llvm::SmallVector inputs; + for (auto input : op->getOperands()) { + auto blkTy = RankedTensorType::get({(int64_t)1}, input.getType()); + auto inputSplat = rewriter.create(loc, blkTy, input); + inputs.push_back(inputSplat.getResult()); + } + auto blkOp = rewriter.create(loc, inputs); + Value offset = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto extractOp = + rewriter.create(loc, blkOp.getResult(), offset); + rewriter.replaceOp(op, extractOp); + return success(); + } +}; + +/* + * Rewrite tt.make_tensor_ptr with non-contiguous order to + * tt.make_tensor_ptr + tt.load + tt.trans. 
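+ *
+ * Rough sketch only (hand-written pseudo-IR; the shapes, order values and
+ * operand lists below are illustrative assumptions, not taken from the
+ * actual pattern implementation). A load through a block pointer whose
+ * order is not the contiguous default:
+ *
+ *   %p = tt.make_tensor_ptr %base, ... {order = array<i32: 0, 1>}  // non-contiguous order
+ *   %v = tt.load %p
+ *
+ * is expected to become a contiguous-order block pointer plus an explicit
+ * transpose of the loaded tile:
+ *
+ *   %p2 = tt.make_tensor_ptr %base, ... {order = array<i32: 1, 0>}
+ *   %t  = tt.load %p2
+ *   %v  = tt.trans %t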
+ */ +class MakeTensorPtrCanonicalizer + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::MakeTensorPtrOp op, + PatternRewriter &rewriter) const override; +}; + +class ReduceSingleCanonicalizer : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(triton::ReduceOp reduceOp, + PatternRewriter &rewriter) const override; +}; + +class DenseConstantConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class MakeRangeConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class SplatConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ReshapeConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ExpandDimsConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class ClampFConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class BroadcastConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +template +class ReductionOpBaseConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(OpTy op, + typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto sourceType = + cast(adaptor.getOperands().front().getType()); + assert(sourceType.hasRank() && "Expected input is ranked"); + + int64_t axis = op.getAxis(); + assert(axis >= 0 && axis < sourceType.getRank() && "Expected reduction axis is within operand's rank"); + + auto reductionOps = this->getRedOps(op); + if (reductionOps.size() == 1) { + return this->convertToTargetOp(op, adaptor, rewriter); + } + return this->convertToTargetOpExtended(op, adaptor, rewriter); + } + +protected: + llvm::SmallVector getRedOps(OpTy redOp) const { + auto redBody = redOp.getBody(); + return llvm::map_to_vector(redBody->without_terminator(), + [](Operation &op) { return &op; }); + } + + arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, + Operation *redOp, + Type constantType) const { + const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); + + auto attr = llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + return rewriter.getFloatAttr(constantType, 0.f); + }) + .Case([&](arith::AddIOp) { + return 
rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::MulFOp) { + return rewriter.getFloatAttr(constantType, 1.f); + }) + .template Case([&](auto) { + return rewriter.getFloatAttr( + constantType, -std::numeric_limits::infinity()); + }) + .template Case([&](auto) { + return rewriter.getFloatAttr( + constantType, std::numeric_limits::infinity()); + }) + .Case([&](arith::MinSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxIntN(bitWidth)); + }) + .Case([&](arith::MinUIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxUIntN(bitWidth)); + }) + .Case([&](arith::MaxSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::minIntN(bitWidth)); + }) + .Case([&](arith::MaxUIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::OrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::AndIOp) { + return rewriter.getIntegerAttr(constantType, 1); + }) + .Case([&](arith::XOrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not supported yet"); + return nullptr; + }); + + return rewriter.create(redOp->getLoc(), constantType, attr); + } + + bool requiresF32Conversion(const Type elemType, Operation *redOp) const { + return isa(elemType) && + elemType.getIntOrFloatBitWidth() < + Float32Type::get(elemType.getContext()).getWidth() && + (isa(redOp) || isa(redOp)); + } + + Value getRedElement( + Value lhs, Value rhs, const Location loc, Operation *redOp, OpBuilder &b, + const bool convertLhsToF32Precision) const { + return llvm::TypeSwitch(redOp) + .template Case([&](auto redOp) { + if (convertLhsToF32Precision) { + lhs = b.create(loc, Float32Type::get(b.getContext()), + lhs); + } + return b.create(loc, lhs, rhs); + }) + .template Case( + [&](auto redOp) { return b.create(loc, lhs, rhs); }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); + } + + virtual bool isReductionOpSupported(Operation *redOp) const = 0; + + virtual LogicalResult + convertToTargetOp(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const = 0; + + virtual LogicalResult + convertToTargetOpExtended(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const = 0; +}; + +class ReduceConverter : public ReductionOpBaseConverter { +public: + explicit ReduceConverter(MLIRContext *context) + : ReductionOpBaseConverter(context) {} + + using ReductionOpBaseConverter::ReductionOpBaseConverter; + +protected: + bool isReductionOpSupported(Operation *redOp) const override; + + LogicalResult + convertToTargetOp(triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + LogicalResult + convertToTargetOpExtended(triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +}; + +class ScanConverter : public ReductionOpBaseConverter { +public: + explicit ScanConverter(MLIRContext *context) + : ReductionOpBaseConverter(context) {} + + using ReductionOpBaseConverter::ReductionOpBaseConverter; + +protected: + bool isReductionOpSupported(Operation *redOp) const override; + + LogicalResult + convertToTargetOp(triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + LogicalResult + convertToTargetOpExtended(triton::ScanOp op, typename triton::ScanOp::Adaptor 
adaptor, + ConversionPatternRewriter &rewriter) const override; + +}; + +class ExternElementwiseClOpConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class UnrealizedCastConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class JoinConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class SplitConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class CatConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class GatherConverter : public OpConversionPattern { +private: + static constexpr llvm::StringRef gatherFuncNameBase = "triton_gather"; + +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class YieldConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +template || + std::is_same_v>> +class LoopConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(LoopOpTy op, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallDenseMap known; + + op->removeAttr("UnhandledLoopOp"); + BlockDataParser::rewriteLoopOp(op, rewriter, known); + return success(); + } +}; + +class AdvanceConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class MakeTensorPtrConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + explicit MakeTensorPtrConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class TransposeConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class BitcastConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class TritonMulhiuiConverter : public OpConversionPattern { 
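+  // tt.mulhiui returns the most significant N bits of the full 2N-bit
+  // product of its two unsigned N-bit operands. As a sketch of one possible
+  // lowering (the authoritative logic lives in the .cpp implementation of
+  // matchAndRewrite), arith.mului_extended produces both halves of the
+  // product and the high half is the desired result:
+  //   %low, %high = arith.mului_extended %a, %b : tensor<128xi32>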
+public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class TritonPreciseSqrtConverter + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class DeviceAssertConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static constexpr llvm::StringRef printFuncNameBase = "triton_assert"; + static constexpr llvm::StringRef msgAttrName = "msg"; + +public: + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class DevicePrintConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + static constexpr llvm::StringRef printFuncNameBase = "triton_print"; + static constexpr llvm::StringRef prefixAttrName = "prefix"; + static constexpr llvm::StringRef hexAttrName = "hex"; + +public: + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct MatmulConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + + +struct SortOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(triton::SortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + + +struct DotScaledConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotScaledOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +class PtrToIntConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // end of namespace TTOpConverters + +#endif diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h new file mode 100644 index 000000000..c47800035 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H +#define TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#define GEN_PASS_CLASSES +#include "ascend/triton-adapter/include/TritonToLinalg/Passes.h.inc" + +extern int nd2nzFlag; +extern bool existDotFlag; + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToLinalgPass(); + +enum TensorKind { NONE = -1, INPUT = 0, OUTPUT = 1, INPUT_OUTPUT = 2 }; + +} // namespace triton +} // namespace mlir + +namespace { + +using namespace mlir; +using namespace triton; +const std::string globalKernelAttr = "global_kernel"; +const std::string kernelMixModeName = "mix_mode"; +const unsigned INT_BIT_WIDTH = 32; +const unsigned SET_INIT_SIZE = 16; + +class TritonTypeConverter : public mlir::TypeConverter { +public: + explicit TritonTypeConverter(); +}; + +class TritonToLinalgPass : public TritonToLinalgBase { + + static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = + LAUNCH_GRID_RANK * 2; + +private: + // Grid construction: num_programs is 3-D, program_id is 3-D + // Remember that 'xxxOp' is usually a pointer, so we can change the target memory + // without passing a reference argument + void addProgramInfo(triton::FuncOp func, bool globalKernel); + + template + void addTensorKindToArguments(OpTy op, triton::FuncOp func, TensorKind tensorKind); + + void convertTTFunc(triton::FuncOp func, const bool existDot); + + LogicalResult convertMultipleBlockControlFlow(Operation *funcOp, + OpBuilder &builder); + // Handle nested if/else + scf::IfOp transformNestedIfElse(Operation &nestedBranch, OpBuilder &builder); + + void addDynamicLegal(ConversionTarget &target, + TritonTypeConverter &tritonTypeConverter); + + void + populateTritonToLinalgCanonicalizationPatterns(RewritePatternSet &patterns); + + void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + unsigned int launchGridRank); + + LogicalResult processDescriptorOperations(ModuleOp moduleOp); + +public: + void getDependentDialects(DialectRegistry &registry) const override; + + void runOnOperation() override; +}; +} // namespace + +#endif // TRITON_ADAPTER_CONVERSION_TRITONTOLINALG_H diff --git a/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h new file mode 100644 index 000000000..e2727fa4c --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToLinalg/UseAnalysis.h @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_ANALYSIS_USEANALYSIS_H +#define TRITON_ANALYSIS_USEANALYSIS_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +enum class UseType { + Undefined, // Initial state + DataUse, // value used for tensor computation only + MetaUse, // value used for metadata only + MixUse // value used for both tensor computation and metadata +}; + +struct UseInfo : public dataflow::AbstractSparseLattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UseInfo) + using AbstractSparseLattice::AbstractSparseLattice; + + // Lattice state transfer function + ChangeResult meetUseType(const UseType &other) { + if (other == UseType::Undefined) { + return ChangeResult::NoChange; + } + + switch (type) { + case UseType::Undefined: + type = other; + return ChangeResult::Change; + case UseType::DataUse: + case UseType::MetaUse: + if (type == other) { + return ChangeResult::NoChange; + } else { + type = UseType::MixUse; + return ChangeResult::Change; + } + case UseType::MixUse: + return ChangeResult::NoChange; + default: + llvm_unreachable("bad type"); + } + } + + ChangeResult meet(const AbstractSparseLattice &other) override { + auto rhs = reinterpret_cast(&other); + return meetUseType(rhs->type); + } + + void print(raw_ostream &os) const override { + switch (type) { + case UseType::DataUse: + os << "DataUse"; + break; + case UseType::MetaUse: + os << "MetaUse"; + break; + case UseType::MixUse: + os << "MixUse"; + break; + default: + os << "Undefined"; + } + } + + UseType type = UseType::Undefined; +}; + +class UseAnalysis : public dataflow::SparseBackwardDataFlowAnalysis { +public: + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + +#if LLVM_VERSION_MAJOR >= 20 + LogicalResult visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; +#else + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; +#endif + + void visitBranchOperand(OpOperand &operand) override { return; } + + void visitCallOperand(OpOperand &operand) override { return; } + + void setToExitState(UseInfo *lattice) override { + lattice->type = UseType::Undefined; + } + +private: + void propagateUse(UseInfo *lattice, const UseType &type) { + auto changed = lattice->meetUseType(type); + propagateIfChanged(lattice, changed); + } + + void propagateResults(UseInfo *lattice, ArrayRef results) { + auto changed = ChangeResult::NoChange; + for (auto result : results) { + changed |= lattice->meet(*result); + } + propagateIfChanged(lattice, changed); + } +}; + +class MetaUseEraser : public RewritePattern { +public: + MetaUseEraser(MLIRContext *context); + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; +}; + +LogicalResult runUseAnalysis(triton::FuncOp &funcOp); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONTOAFFINE_TRITONUSEANALYSIS_H diff --git a/third_party/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h b/third_party/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h new file mode 100644 index 000000000..ee975662a --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h @@ -0,0 +1,131 @@ +#pragma once + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + 
+#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_DECL_BUBBLEUPOPERATION +#include "../../include/TritonToUnstructure/Passes.h.inc" + +#define GEN_PASS_DEF_BUBBLEUPOPERATION +#include "../../include/TritonToUnstructure/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> +createBubbleUpOperationPass(const BubbleUpOperationOptions &options = {}); + +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace triton; + +class BubbleUpExtract : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit BubbleUpExtract(MLIRContext *context, bool enableAggressiveMode); + + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override; + +private: + Value createExtractOp(Value value, ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template + void bubbleUpIntBinaryOp(Operation *op, BinOpTy binOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template + void bubbleUpFloatBinaryOp(Operation *op, BinOpTy binOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template + void bubbleUpOperation(Operation *op, ParentOpTy parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const = delete; + + template <> + void bubbleUpOperation(Operation *op, arith::ExtSIOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + + template <> + void bubbleUpOperation(Operation *op, arith::CmpIOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, + arith::TruncFOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + + template <> + void bubbleUpOperation(Operation *op, + arith::ExtFOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + + template <> + void bubbleUpOperation(Operation *op, + arith::FPToSIOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, + arith::SIToFPOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void + bubbleUpOperation(Operation *op, triton::ClampFOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, arith::CmpFOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, + triton::BroadcastOp parentOp, + ArrayRef indices, + Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, + triton::ExpandDimsOp parentOp, + ArrayRef indices, + Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, + triton::SplatOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, + triton::MakeRangeOp parentOp, + ArrayRef indices, + Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, math::FloorOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + template <> + void bubbleUpOperation(Operation *op, math::CeilOp parentOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const; + + bool enableAggressiveMode; +}; + +class BubbleUpOperationPass + : public ::impl::BubbleUpOperationBase { +public: 
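+  // This pass drives rewrite patterns such as BubbleUpExtract above which,
+  // as we read it, hoists a scalar tensor.extract through its element-wise
+  // producer so that only the extracted element is computed. Hand-written
+  // sketch (types and value names are illustrative only):
+  //   %s = arith.addi %a, %b : tensor<128xi32>
+  //   %e = tensor.extract %s[%i] : tensor<128xi32>
+  // ~~>
+  //   %ea = tensor.extract %a[%i] : tensor<128xi32>
+  //   %eb = tensor.extract %b[%i] : tensor<128xi32>
+  //   %e  = arith.addi %ea, %eb : i32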
+ explicit BubbleUpOperationPass(const BubbleUpOperationOptions &options); + void runOnOperation() override; +}; \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToUnstructure/CMakeLists.txt b/third_party/ascend/triton-adapter/include/TritonToUnstructure/CMakeLists.txt new file mode 100644 index 000000000..ecdd8d93b --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToUnstructure/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructure) +add_public_tablegen_target(TritonToUnstructureConversionPassIncGen) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h b/third_party/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h new file mode 100644 index 000000000..2b922c2b8 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h @@ -0,0 +1,228 @@ +#ifndef TRITON_ANALYSIS_OFFSETANALYSIS_H +#define TRITON_ANALYSIS_OFFSETANALYSIS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace triton { + +struct PtrOffsetInfo { + /** + Possible status of the ptr offset: + - ScalarLike: + - All of the tensor's elements are the same, such as [[2.0,2.0,2.0],[2.0,2.0,2.0]] + - Constant integer or floating-point such as 2, 2.0, and `load + tensor<1xptr>` + - Unstructured: + - Not a `ScalarLike` ptr offset + - Or satisfies any of the conditions below: + - Non-contiguous stride, such as + - `muli [0,1,2,3] [0,1,2,3]` => [0,1,4,9] + - `divsi [9,8,7] [3,2,1]` => [3,4,7] + - `minsi [3,4,5] [5,4,3]` => [3,4,3] + - Derived from a non-`scalarLike` floating-point value, such as + - `fptosi [1.0,2.0,3.0]` => [1,2,3] + - Value unknown at compile time + - `load %ptr, %offset` => %value + - Structured: + - orthogonal to `Unstructured` + - if PtrOffsetInfo isn't `Unstructured`, it is `Structured` + + In short: + ScalarLike ⊆ Structured + Unstructured = {x| x ∉ Structured} + + Example: + ``` + %y = sitofp %x + %z = fptosi %y + ``` + If %x is scalarLike (structured), %z will be scalarLike (structured) as well. + If %x is non-scalarLike structured, %z will be unstructured. 
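+
+   One more illustration, based on our reading of the rules above (the ops
+   and values are illustrative only, not an exhaustive specification):
+   ```
+   %r = tt.make_range {start = 0 : i32, end = 128 : i32}  // structured, unit stride
+   %s = tt.splat %c                                       // scalarLike
+   %o = arith.addi %r, %s                                 // still structured
+   %u = arith.muli %r, %r                                 // non-constant stride => unstructured
+   ```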
+ */ + +public: + explicit PtrOffsetInfo(); + PtrOffsetInfo(const PtrOffsetInfo &other); + + explicit PtrOffsetInfo(const Value &ptr); + explicit PtrOffsetInfo(ArrayRef structured); + explicit PtrOffsetInfo(const Value &ptr, bool structured); + explicit PtrOffsetInfo(const Value &ptr, ArrayRef structured); + explicit PtrOffsetInfo(const Value &ptr, const Value &offset, bool structured); + explicit PtrOffsetInfo(const Value &ptr, const Value &offset, ArrayRef structured); + + PtrOffsetInfo &operator=(const PtrOffsetInfo &other); + + Value getPtr() const; + Value getOffset() const; + bool isScalarLike() const; + bool isNegativeFlag() const; + SmallVector &getStructuredRef(); + const SmallVector &getStructured() const; + int getRank() const; + + void setPtr(const Value &ptr); + void setOffset(const Value &offset); + void setStructured(); + void setStructured(int rank); + void setUnstructured(); + void setUnstructured(int rank); + void setStructured(ArrayRef structured); + void setStructured(const PtrOffsetInfo &other); + void setScalarLike(bool scalarLike); + void setNegativeFlag(bool negativeFlag); + bool isStructured(int dim) const; + bool isStructured() const; + bool isUnstructured() const; + + void setZeroOffset(); +private: + Value ptr; + Value offset; + + bool scalarLike = false; + bool negativeFlag = false; + SmallVector structured; +}; + +PtrOffsetInfo combineInfo(const PtrOffsetInfo &lhs, const PtrOffsetInfo &rhs); + +void parse(Value operand, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseLoopRegionIterArg(LoopLikeOpInterface loopOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap, + BlockArgument regionIterArg); + +void parseArithOp(Operation *arithOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseTritonOp(Operation *tritonOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseTritonOp(Operation *tritonOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseAddPtr(triton::AddPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSplat(triton::SplatOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +template +void parseBinaryOp(BinOpTy op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseAddI(arith::AddIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseIndexCast(arith::IndexCastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +template +void parseConstantOp(ConstOpTy dst, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMakeRange(triton::MakeRangeOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseExtSI(arith::ExtSIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseBitcast(triton::BitcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMulI(arith::MulIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseBroadcast(triton::BroadcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseExpandDims(triton::ExpandDimsOp op, const Location 
&loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseClampF(triton::ClampFOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSelect(arith::SelectOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseFPToSI(arith::FPToSIOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseSIToFP(arith::SIToFPOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMakeTensorDesc(triton::MakeTensorDescOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseMakeTensorPtr(triton::MakeTensorPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseAdvance(triton::AdvanceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseReduce(triton::ReduceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseReduceReturn(triton::ReduceReturnOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseIf(scf::IfOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst); + +void parseYield(scf::YieldOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseLoopOp(LoopLikeOpInterface op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst); + +void parseExtractSlice(tensor::ExtractSliceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseExtract(tensor::ExtractOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); + +void parseIntToPtr(triton::IntToPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap); +} // namespace triton + +} // namespace mlir + +#endif \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToUnstructure/Passes.h b/third_party/ascend/triton-adapter/include/TritonToUnstructure/Passes.h new file mode 100644 index 000000000..8ca05c10b --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToUnstructure/Passes.h @@ -0,0 +1,16 @@ +#ifndef TRITON_ADAPTER_TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES_H +#define TRITON_ADAPTER_TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES_H + +#include "BubbleUpOperation.h" +#include "UnstructureConversionPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "ascend/triton-adapter/include/TritonToUnstructure/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TRITON_ADAPTER_TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/TritonToUnstructure/Passes.td b/third_party/ascend/triton-adapter/include/TritonToUnstructure/Passes.td new file mode 100644 index 000000000..6086b4595 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToUnstructure/Passes.td @@ -0,0 +1,20 @@ +#ifndef TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES +#define TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToUnstructure : Pass<"triton-to-unstructure", "mlir::ModuleOp"> { + let summary = "Convert Triton for unstructure case"; + let constructor = "triton::createTritonToUnstructurePass()"; +} + +def BubbleUpOperation : Pass<"bubble-up-operation", "mlir::ModuleOp"> { + let summary = "Apply bubble up 
operation optimization"; + let constructor = "triton::createBubbleUpOperationPass()"; + let options = [ + Option<"enableAggressiveMode", "enable-aggressive-mode", "bool", "true", + "Enable aggressive bubble up operation.">, + ]; +} + +#endif // TRITON_TO_UNSTRUCTURE_CONVERSION_PASSES diff --git a/third_party/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h b/third_party/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h new file mode 100644 index 000000000..62a42a037 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h @@ -0,0 +1,109 @@ +#ifndef TRITON_ADAPTER_UNSTRUCTURECONVERSION_H +#define TRITON_ADAPTER_UNSTRUCTURECONVERSION_H + +#include "TritonToUnstructure/OffsetAnalysis.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" + +#define GEN_PASS_DEF_TRITONTOUNSTRUCTURE +#include "ascend/triton-adapter/include/TritonToUnstructure/Passes.h.inc" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToUnstructurePass(); + +} // namespace triton +} // namespace mlir + +namespace { + +using namespace mlir; +using namespace triton; + +// For example, in unstructured load case +// %0 = tt.load %structured : tensor<128x128x!tt.ptr> +// %ptr_2 = tt.splat %arg1 : !tt.ptr -> tensor<128x128x!tt.ptr> +// %1 = tt.addptr %ptr_2, %0 : tensor<128x128x!tt.ptr>, +// tensor<128x128xi32> %2 = tt.load %1 : tensor<128x128x!tt.ptr> tt.store +// %output %2 : tensor<128x128x!tt.ptr> +// +// +// In this case, this will be converted to +// +// %0 = tt.load %structured : tensor<128x128x!tt.ptr> +// %1 = tensor.empty() : tensor<128x128xf32> +// %2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %1) -> +// (tensor<128x128xf32>) { +// %4 = scf.for %arg4 = %c0 to %c128 step %c1 iter_args(%arg5 = %arg3) -> +// (tensor<128x128xf32>) { +// %extracted = tensor.extract %10[%arg3, %arg5] {DiscreteMemAccess} : +// tensor<128x128xi32> %5 = arith.extsi %extracted : i32 to i64 %6 = +// tt.addptr %arg1, %5 : !tt.ptr, i64 %7 = tt.load %6 +// {DiscreteMemAccess} : tt.ptr %inserted_slice = tensor.insert_slice +// %7 into %arg5[%arg2, %arg4] [1, 1] [128, 1] {DiscreteMemAccess} : +// tensor<1x1xf32> into tensor<128x128xf32> scf.yield %inserted_slice : +// tensor<128x128xf32> +// } +// scf.yield %4 : tensor<128x128xf32> +// } +// tt.store %output %2 : tensor<128x128x!tt.ptr> +template +class UnstructuredMemAccessConverter : public OpRewritePattern { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v); + +public: + using OpRewritePattern::OpRewritePattern; + + explicit UnstructuredMemAccessConverter( + MLIRContext *context, + const llvm::DenseMap &offsetMap); + LogicalResult matchAndRewrite(MemAccOpTy op, + PatternRewriter &rewriter) const override; + +private: + Value createExtractOp(Location loc, Value value, ArrayRef iterIdx, + PatternRewriter &rewriter) const; + template + typename std::enable_if, void>::type + splatAndLoadScenario(MemAccOpTy op, int rank, + PatternRewriter &rewriter) const; + + MemAccOpTy createMemAccOp(MemAccOpTy op, Value ptrToAccess, Location loc, + ArrayRef iterIdx, + PatternRewriter &rewriter) const; + + void AddAssertForAddPtr(MemAccOpTy op, const Value &opoffset, + PatternRewriter &rewriter) const; + + const llvm::DenseMap &offsetMap; +}; + +class TritonToUnstructurePass + : public ::impl::TritonToUnstructureBase { +public: + void getDependentDialects(DialectRegistry 
&registry) const override; + + void runOnOperation() override; + +private: + void runPreparse(LoopLikeOpInterface op); + template || + std::is_same_v || + std::is_same_v || + std::is_same_v>> + void runParse(MemAccOpTy op); + llvm::DenseMap offsetMap; + llvm::DenseMap offsetMapForLoopArgs; +}; + +} // namespace + +#endif // TRITON_ADAPTER_UNSTRUCTURECONVERSION_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h b/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h new file mode 100644 index 000000000..f63b7e34b --- /dev/null +++ b/third_party/ascend/triton-adapter/include/Utils/InterleaveOptimization.h @@ -0,0 +1,71 @@ +#pragma once + +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include "TritonToLinalg/UseAnalysis.h" +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include +#include +#include +#include +#include + +namespace mlir { +namespace triton { + +enum class IndexMode : int { EVEN_MODE = 0, ODD_MODE = 1 }; + +MemRefType expandInterleaveMemRefType(MemRefType originType); + +std::pair +recountReinterpretCastOffset(OpFoldResult originOffset, Builder &builder); + +LogicalResult +DeinterleaveStatusOptimization(triton::LoadOp op, + triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter); + +LogicalResult DeinterleaveStatusWithMaskOptimization( + triton::LoadOp op, triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter, MaskState &mstate, + memref::AllocOp originAllocOp); + +LogicalResult +InterleaveStatusOptimization(SmallVector materializeVec); + +LogicalResult +InterleaveStatusWithMaskOptimization(SmallVector materializeVec); + +} // namespace triton +} // namespace mlir \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/include/Utils/Utils.h b/third_party/ascend/triton-adapter/include/Utils/Utils.h new file mode 100644 index 000000000..ca37e1156 --- /dev/null +++ b/third_party/ascend/triton-adapter/include/Utils/Utils.h @@ -0,0 +1,201 @@ +#ifndef TRITONNPU_UTILS_UTILS_H +#define TRITONNPU_UTILS_UTILS_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include 
"mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" + +#include +#include + +namespace mlir { + +namespace ConverterUtils { + +const std::string GeneratedByMakeTensorPtrTAG = "GeneratedByMakeTensorPtr"; +const std::string discreteMaskAttrName = "DiscreteMask"; +const std::string discreteAttrName = "DiscreteMemAccess"; + +bool isaPermutedMemRefType(MemRefType); + +std::optional getLastStrideOfReinterpretCastOp(memref::ReinterpretCastOp op); + +Value getTransposedValue(Value source, const Location loc, + ConversionPatternRewriter &rewriter, + llvm::ArrayRef order); + +SmallVector getNParallelLoopsAttrs(unsigned n); + +Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter); + +memref::SubViewOp makeSubViewOp(Value src, + const llvm::SmallVector &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter); + +tensor::ExtractSliceOp makeExtractSliceOp(Value src, + const llvm::SmallVector &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter); + +std::optional getFullShapeOp(Value val, + ConversionPatternRewriter &rewriter); + +SmallVector +getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, + const Location &loc, ConversionPatternRewriter &rewriter); + +SmallVector getBroadcastDims(RankedTensorType src, + RankedTensorType dst); + +SmallVector getUnbroadcastDims(RankedTensorType src, + RankedTensorType dst); + +} // namespace ConverterUtils + +class ConversionPatternRewriter; + +namespace triton { + +enum class IndirectLoadInterfaceOpType { Undefined = 0, Load = 1, Calc = 2 }; + +// Traceback from rootOp to find the targetOp with the specified condition +mlir::Operation * +findFirstMatchingOperandDef(mlir::Operation *rootOp, + const std::function &condFn); + +void traverseBackwardUpdateOperandChainIf( + Operation *op, std::function conditionFn, + std::function stopFn, + std::function actionFn, OpBuilder &builder, + DenseSet &handledOperation); + +void traverseBackwardUpdateOperandChainIf( + Operation *rootOp, std::function conditionFn, + std::function stopFn, + std::function actionFn); + +void traverseForwardUpdateUserChainIf( + Operation *op, std::function conditionFn, + std::function stopFn, + std::function actionFn, OpBuilder &builder, + llvm::SmallPtrSet &stopOps); + +void traverseForwardUpdateUserChainIf( + Operation *rootOp, std::function conditionFn, + std::function stopFn, + std::function actionFn, + llvm::SmallPtrSet &stopOps); + +// UseAnalysis will tag operations whose results are used only as meta-data +// with "MetaUse" tag. +bool isMetaUse(Operation *op); + +bool isMixUse(Operation *op); + +IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op); + +bool opIsIndirectLoad(Operation *op); + +bool opIsIndirectCalc(Operation *op); + +/// Maximum expected rank for loop tiling in tensor operations. +static constexpr int kMaxTiledRank = 4; + +/// This function generates a series of `scf.for` loops for the given dimensions +/// in `loopDims`. Although the loops are created sequentially, nesting is +/// simulated by adjusting the insertion point to the body of the last created +/// loop. This allows the `bodyFunc` to be inserted into the innermost scope. +/// +/// \param rewriter The MLIR OpBuilder used to create operations. +/// \param loc The source location information for debuggability. +/// \param target The memref value whose dimensions are being looped over. +/// \param loopDims An array of dimension indices to create loops for. 
+/// \param bodyFunc A callable that defines the operations to insert in the +/// innermost loop. +/// It takes a SmallVector of induction variables (one per +/// loop). +/// +template +void createSimpleNestedLoops(OpBuilder &rewriter, Location loc, Value target, + ArrayRef loopDims, Func bodyFunc) { + MemRefType type = cast(target.getType()); + int rank = type.getRank(); + + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + + llvm::SmallVector loops; + llvm::SmallVector ivs; + + for (int dim : loopDims) { + Value ub; + if (type.isDynamicDim(dim)) { + ub = rewriter.create(loc, target, dim).getResult(); + } else { + ub = rewriter.create(loc, type.getDimSize(dim)); + } + + auto forOp = rewriter.create(loc, zero, ub, one); + rewriter.setInsertionPointToStart(forOp.getBody()); + loops.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + } + + bodyFunc(ivs); + + if (!loops.empty()) { + rewriter.setInsertionPointAfter(loops.front()); + } +} + +scf::ForOp createNestedLoops( + OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, + ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector &ivs, + ValueRange initArgs, + function_ref &, ValueRange)> + bodyBuilder); + +ModuleOp getModuleOpFromOperation(Operation *op); + +} // namespace triton + +class OpBuilder; + +OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b); + +LogicalResult +addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, + linalg::ReduceOp reduceOp); + +OpFoldResult getOpFoldResultOfLayoutInfo(Value value, OpBuilder &builder); + +} // namespace mlir + +#endif // TRITONNPU_UTILS_UTILS_H \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/CMakeLists.txt new file mode 100644 index 000000000..a0a430bd4 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(TritonToAnnotation) +add_subdirectory(TritonToHIVM) +add_subdirectory(TritonToLinalg) +add_subdirectory(DiscreteMaskAccessConversion) +add_subdirectory(TritonToUnstructure) +add_subdirectory(Utils) diff --git a/third_party/ascend/triton-adapter/lib/DiscreteMaskAccessConversion/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/DiscreteMaskAccessConversion/CMakeLists.txt new file mode 100644 index 000000000..0c7f14568 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/DiscreteMaskAccessConversion/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(DiscreteMaskAccessConversion + DiscreteMaskAccessConversionPass.cpp + + DEPENDS + DiscreteMaskAccessConversionPassIncGen + + LINK_LIBS + BiShengIRHIVMDialect + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR +) diff --git 
a/third_party/ascend/triton-adapter/lib/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.cpp b/third_party/ascend/triton-adapter/lib/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.cpp new file mode 100644 index 000000000..72bea5aa2 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.cpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. + */ + +#include "DiscreteMaskAccessConversion/Passes.h" +#include "Utils/Utils.h" + +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" +#include "TritonToLinalg/MaskAnalysis.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DISCRETEMASKACCESSCONVERSION +#include "ascend/triton-adapter/include/DiscreteMaskAccessConversion/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace hivm; + +struct DiscreteMaskStoreConversion + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +LogicalResult matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const final { + auto mask = op.getMask(); + auto loc = op.getLoc(); + auto dst = op.getPtr(); + auto src = op.getValue(); + + if (!mask) + return failure(); + + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + if (!isContMask.failed()) { + mstate.eraseInsertedOps(op, rewriter); + return failure(); + } + + auto loadFromDstOp = rewriter.create( + loc, dst, op.getCache(), op.getEvict(), false); + + auto selOp = rewriter.create(loc, mask, src, loadFromDstOp.getResult()); + auto newStore = rewriter.create( + loc, dst, selOp, op.getCache(), op.getEvict()); + newStore->setAttr(ConverterUtils::discreteMaskAttrName, UnitAttr::get(rewriter.getContext())); + rewriter.replaceOp(op, newStore); + return success(); +} +}; + +struct DiscreteMaskLoadConversion + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +LogicalResult matchAndRewrite(triton::LoadOp op, + PatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + auto other = op.getOther(); + auto mask = op.getMask(); + auto ptr = op.getPtr(); + + if (!mask) + return failure(); + + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + if (!isContMask.failed()) { + mstate.eraseInsertedOps(op, rewriter); + return failure(); + } + + if (!other) { + auto ptrType = ptr.getType(); + auto elementType = getElementTypeOrSelf(ptrType); + if (auto intType = dyn_cast(ptrType)) { + other = rewriter.create( + loc, elementType, rewriter.getIntegerAttr(elementType, 0)); + } else if (auto floatType = dyn_cast(ptrType)) { + other = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 0.0)); + } else { + llvm_unreachable("Unsupported type for constant creation"); + } + } + + auto newLoadOp = rewriter.create( + loc, ptr, op.getCache(), op.getEvict(), op.getIsVolatile()); + auto discreteMaskOp = rewriter.create(loc, mask, newLoadOp, other); + rewriter.replaceOp(op, discreteMaskOp); + return success(); +} +}; + +struct DiscreteMaskAtomicAddConversion : OpRewritePattern { +using OpRewritePattern::OpRewritePattern; +LogicalResult matchAndRewrite(triton::AtomicRMWOp op, PatternRewriter &rewriter) const final { + 
if (op.getAtomicRmwOp() != triton::RMWOp::FADD && op.getAtomicRmwOp() != triton::RMWOp::ADD) { + return failure(); + } + auto loc = op.getLoc(); + auto ptr = op.getPtr(); + auto value = op.getVal(); + auto mask = op.getMask(); + + if (!mask) + return failure(); + + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + if (!isContMask.failed()) { + mstate.eraseInsertedOps(op, rewriter); + return failure(); + } + + mlir::Value zeros; + auto valueType = value.getType(); + if (auto tensorType = mlir::dyn_cast(valueType)) { + auto elemType = tensorType.getElementType(); + auto zeroAttr = rewriter.getZeroAttr(elemType); + auto denseAttr = mlir::DenseElementsAttr::get(tensorType, zeroAttr); + zeros = rewriter.create(loc, denseAttr); + } else if (mlir::isa(valueType) || mlir::isa(valueType)) { + auto zeroAttr = rewriter.getZeroAttr(valueType); + zeros = rewriter.create(loc, zeroAttr); + } else { + op.emitError() << "Unsupported value type for select: " << valueType << "\n"; + return failure(); + } + auto maskedValue = rewriter.create(loc, mask, value, zeros); + auto newAtomicAddOp = rewriter.create( + loc, value.getType(), op.getAtomicRmwOp(), ptr, maskedValue, mlir::Value(), op.getSem(), op.getScope()); + rewriter.replaceOp(op, newAtomicAddOp); + return success(); +} +}; + +void DiscreteMaskAccessConversionPass::runOnOperation() { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) { + moduleOp->emitError("failed to apply discrete mask access patterns"); + signalPassFailure(); + } +} + +std::unique_ptr> mlir::triton::createDiscreteMaskAccessConversionPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToAnnotation/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/TritonToAnnotation/CMakeLists.txt new file mode 100644 index 000000000..4492c6f8b --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToAnnotation/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonToAnnotation + TritonToAnnotation.cpp + + DEPENDS + TritonToAnnotationConversionPassIncGen + + LINK_LIBS + BiShengIRAnnotationDialect + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR +) diff --git a/third_party/ascend/triton-adapter/lib/TritonToAnnotation/TritonToAnnotation.cpp b/third_party/ascend/triton-adapter/lib/TritonToAnnotation/TritonToAnnotation.cpp new file mode 100644 index 000000000..a97fc83a8 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToAnnotation/TritonToAnnotation.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. 
+ */ + +#include "TritonToAnnotation/Passes.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_TRITONTOANNOTATION +#include "ascend/triton-adapter/include/TritonToAnnotation/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; + +namespace { +struct TritonToAnnotationPass + : public mlir::triton::impl::TritonToAnnotationBase< + TritonToAnnotationPass> { + void runOnOperation() override; +}; +} // namespace + +struct TritonAnnotationConversionPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::AnnotationOp op, + PatternRewriter &rewriter) const final { + auto markOp = rewriter.create(op.getLoc(), op.getSrc()); + // Forward all annotations. + markOp->setAttrs(op->getAttrs()); + rewriter.eraseOp(op); + return success(); + } +}; + +void TritonToAnnotationPass::runOnOperation() { + auto module = getOperation(); + ConversionTarget target(getContext()); + target.addLegalDialect(); + + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext()); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr> +mlir::triton::createTritonToAnnotationPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToHIVM/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/TritonToHIVM/CMakeLists.txt new file mode 100644 index 000000000..52a3d532d --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToHIVM/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonToHIVM + TritonToHIVM.cpp + + DEPENDS + TritonToHIVMConversionPassIncGen + + LINK_LIBS + BiShengIRHIVMDialect + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + TritonIR +) diff --git a/third_party/ascend/triton-adapter/lib/TritonToHIVM/TritonToHIVM.cpp b/third_party/ascend/triton-adapter/lib/TritonToHIVM/TritonToHIVM.cpp new file mode 100644 index 000000000..67969e20c --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToHIVM/TritonToHIVM.cpp @@ -0,0 +1,162 @@ +/* + * Copyright (c) Huawei Technologies Co. + * Licensed under the MIT license. 
+ */ + +#include "TritonToHIVM/Passes.h" + +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_TRITONTOHIVM +#include "ascend/triton-adapter/include/TritonToHIVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace hivm; + +namespace { + +struct CoreAndPipes { + TCoreTypeAttr core; + PipeAttr producer; + PipeAttr consumer; +}; + +static LogicalResult EmitUnknownOpError(Operation *op, + llvm::StringRef opName) { + op->emitError("Unknown custom operation: ") << opName; + return failure(); +} + +static void CreateSyncBlock(PatternRewriter &rewriter, Location loc, + MLIRContext *ctx, Operation *op, int64_t id, + hivm::SyncBlockMode mode, + PipeAttr pipe1, PipeAttr pipe2) { + auto syncMode = hivm::SyncBlockModeAttr::get(ctx, mode); + auto newOp = rewriter.create( + loc, syncMode, rewriter.getI16IntegerAttr(id), + Value{}, pipe1, pipe2); + rewriter.replaceOp(op, newOp); +} + +static CoreAndPipes GetCoreAndPipes(MLIRContext *ctx, + llvm::StringRef opName, + llvm::StringRef sender) { + // Step 1: Decide pipes + PipeAttr producer; + PipeAttr consumer = PipeAttr::get(ctx, PIPE::PIPE_MTE2); + + if (sender == "cube") { + producer = PipeAttr::get(ctx, PIPE::PIPE_FIX); + } else { + producer = PipeAttr::get(ctx, PIPE::PIPE_MTE3); + } + + // Step 2: Decide core type + TCoreTypeAttr core; + if (sender == "cube") { + if (opName == "sync_block_set") + core = TCoreTypeAttr::get(ctx, TCoreType::CUBE); + else + core = TCoreTypeAttr::get(ctx, TCoreType::VECTOR); + } else { + if (opName == "sync_block_set") + core = TCoreTypeAttr::get(ctx, TCoreType::VECTOR); + else + core = TCoreTypeAttr::get(ctx, TCoreType::CUBE); + } + + return {core, producer, consumer}; +} + +} // end anonymous namespace +namespace { +struct TritonToHIVMPass + : public mlir::triton::impl::TritonToHIVMBase< + TritonToHIVMPass> { + void runOnOperation() override; +}; +} // namespace + +struct TritonCustomOpToHIVMSyncOpConversion + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +LogicalResult matchAndRewrite(triton::CustomOp op, + PatternRewriter &rewriter) const final { + auto *ctx = op->getContext(); + auto loc = op->getLoc(); + auto args = op.getStrArgs(); + auto argAttr = dyn_cast(args[0]); + auto id = dyn_cast(args[1]).getInt(); + llvm::StringRef opName = op.getOpName(); + llvm::StringRef arg = argAttr.getValue(); + + if (opName == "sync_block_all") { + if (arg == "all_cube") { + CreateSyncBlock(rewriter, loc, ctx, op, id, + hivm::SyncBlockMode::ALL_CUBE, + PipeAttr::get(ctx, PIPE::PIPE_FIX), + hivm::PipeAttr{}); + } else if (arg == "all_vector") { + CreateSyncBlock(rewriter, loc, ctx, op, id, + hivm::SyncBlockMode::ALL_VECTOR, + hivm::PipeAttr{}, + PipeAttr::get(ctx, PIPE::PIPE_MTE3)); + } else if (arg == "all") { + CreateSyncBlock(rewriter, loc, ctx, op, id, + hivm::SyncBlockMode::ALL, + PipeAttr::get(ctx, PIPE::PIPE_FIX), + PipeAttr::get(ctx, PIPE::PIPE_MTE3)); + } else { + return EmitUnknownOpError(op, opName); + } + return success(); + } + + if (opName == "sync_block_set") { + auto [coreAttr, prodPipe, consPipe] = GetCoreAndPipes(ctx, opName, arg); + rewriter.replaceOp(op, rewriter.create( + loc, coreAttr, prodPipe, consPipe, + 
rewriter.getIndexAttr(id))); + return success(); + } + + if (opName == "sync_block_wait") { + auto [coreAttr, prodPipe, consPipe] = GetCoreAndPipes(ctx, opName, arg); + rewriter.replaceOp(op, rewriter.create( + loc, coreAttr, prodPipe, consPipe, + rewriter.getIndexAttr(id))); + return success(); + } + + return EmitUnknownOpError(op, opName); +} +}; + + +void TritonToHIVMPass::runOnOperation() { + auto module = getOperation(); + ConversionTarget target(getContext()); + target.addLegalDialect(); + + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext()); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr> mlir::triton::createTritonToHIVMPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp new file mode 100644 index 000000000..462bedcf1 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/ArgMinMaxConverter.cpp @@ -0,0 +1,79 @@ +#include "TritonToLinalg/ArgMinMaxConverter.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +// ArgMinConverter functions +LogicalResult ArgMinConverter::matchComparisonResult( + Value currValue, Value currIndex, Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, Value &comparisonResult) { + LLVM_DEBUG(llvm::dbgs() << "Matching: " << *it << "\n"); + + auto cmpOp = dyn_cast(*it); + auto cmpIOp = dyn_cast(*it++); + if (!cmpOp && !cmpIOp) + return failure(); + + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OLT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + comparisonResult = cmpOp; + } + + if (cmpIOp) { + if ((cmpIOp.getPredicate() != arith::CmpIPredicate::slt && + cmpIOp.getPredicate() != arith::CmpIPredicate::ult) || + currValue != cmpIOp.getLhs() || reduceValue != cmpIOp.getRhs()) { + return failure(); + } + comparisonResult = cmpIOp; + } + + return success(); +} + +float ArgMinConverter::getBaseReductionValue() { + return std::numeric_limits::infinity(); +} + +int8_t ArgMinConverter::getBaseReductionIntValue() { return 127; } + +// ArgMaxConverter functions +LogicalResult ArgMaxConverter::matchComparisonResult( + Value currValue, Value currIndex, Value reduceValue, Value reduceIndex, + mlir::Block::iterator &it, Value &comparisonResult) { + auto cmpOp = dyn_cast(*it); + auto cmpIOp = dyn_cast(*it++); + if (!cmpOp && !cmpIOp) + return failure(); + + if (cmpOp) { + if (cmpOp.getPredicate() != arith::CmpFPredicate::OGT || + currValue != cmpOp.getLhs() || reduceValue != cmpOp.getRhs()) { + return failure(); + } + comparisonResult = cmpOp; + } + + if (cmpIOp) { + if ((cmpIOp.getPredicate() != arith::CmpIPredicate::sgt && + cmpIOp.getPredicate() != arith::CmpIPredicate::ugt) || + currValue != cmpIOp.getLhs() || reduceValue != cmpIOp.getRhs()) { + return failure(); + } + comparisonResult = cmpIOp; + } + + return success(); +} + +float ArgMaxConverter::getBaseReductionValue() { + return -std::numeric_limits::infinity(); +} + +int8_t ArgMaxConverter::getBaseReductionIntValue() { return -128; } + +} // namespace TTOpConverters \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp new file mode 100644 index 000000000..59f60653c 
--- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp @@ -0,0 +1,1820 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/TritonToLinalgPass.h" +#include "Utils/Utils.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include + +#define DEBUG_TYPE "triton-block-ptr-analysis" +namespace mlir { +namespace triton { + +// MemAccType selectMaxMemAccTy(const MemAccType &v1, const MemAccType &v2) { +// return (v1 > v2) ? v1 : v2; +// } + +SmallVector &BlockData::getOffsetsRef() { return this->offsets; } + +SmallVector &BlockData::getSizesRef() { return this->sizes; } + +SmallVector &BlockData::getStridesRef() { return this->strides; } + +Value &BlockData::getSourceRef() { return this->source; } + +OpFoldResult &BlockData::getScalarRef() { return this->scalar; } + +SmallVector BlockData::getOffsets() const { + return this->offsets; +} + +SmallVector BlockData::getSizes() const { return this->sizes; } + +SmallVector BlockData::getStrides() const { + return this->strides; +} + +OpFoldResult BlockData::getOffset(int index) const { + return this->offsets[index]; +} + +OpFoldResult BlockData::getSize(int index) const { return this->sizes[index]; } + +OpFoldResult BlockData::getStride(int index) const { + return this->strides[index]; +} + +OpFoldResult BlockData::getScalar() const { return this->scalar; } + +Value BlockData::getSource() const { return this->source; } + +MemAccType BlockData::getMemAccType() const { return this->memAccTy; }; + +MemAccType &BlockData::getMemAccTypeRef() { return this->memAccTy; }; + +bool BlockData::isScalar() const { return !(this->scalar).isNull(); } + +bool BlockData::isEmpty() const { + return !(this->getRank() || this->source || !(this->scalar).isNull()); +} + +bool BlockData::hasSource() const { return this->source != nullptr; } + +void BlockData::removeSource() { this->source = nullptr; }; + +bool BlockData::hasResElemTy() const { return this->resElemTy != nullptr; } + +Type &BlockData::getResElemTyRef() { return this->resElemTy; } + +Type BlockData::getResElemTy() const { return this->resElemTy; } + +int64_t BlockData::getRank() const { + assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); + return this->offsets.size(); +} + +void BlockData::setResElemTy(const Type &Ty) { this->resElemTy = Ty; } + +void BlockData::setScalar(const OpFoldResult 
&scalar) { this->scalar = scalar; } + +void BlockData::setSource(const Value &src) { this->source = src; } + +void BlockData::setOffsets(const SmallVector &offsets) { + this->offsets = offsets; +} + +void BlockData::setStrides(const SmallVector &strides) { + this->strides = strides; +} + +void BlockData::setSizes(const SmallVector &szs) { + this->sizes = szs; +} + +void BlockData::setMemAccTy(const MemAccType &v) { this->memAccTy = v; } + +void BlockData::setMemAccVal(const MemAccVal v) { this->memAccTy.value = v; } + +OpFoldResult BlockData::inferBlockOffset(const Location &loc, + OpBuilder &builder) const { + OpFoldResult retOffset = builder.getIndexAttr(0); + for (auto ofr : offsets) { + retOffset = addOpFoldResult(retOffset, ofr, loc, builder); + } + return retOffset; +} + +MemRefType BlockData::getResultMemrefType(int64_t offset, + ArrayRef resultShape) const { + SmallVector staticStrides; + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + + auto baseMemrefType = dyn_cast(this->source.getType()); + assert(baseMemrefType && "Invalid element type. It should be a base memref type."); + auto elementType = baseMemrefType.getElementType(); + auto layout = + StridedLayoutAttr::get(this->source.getContext(), offset, staticStrides); + return MemRefType::get(resultShape, elementType, layout); +} + +void BlockData::addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter) { + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + // When both left block and right block have source, it is indirect load. + assert(!(lBlock.hasSource() && rBlock.hasSource()) && + "Don't support each BlockData has own base source pointer"); + this->source = + lBlock.hasSource() ? lBlock.getSourceRef() : rBlock.getSourceRef(); + + assert(!(lBlock.hasResElemTy() && rBlock.hasResElemTy())); + if (lBlock.hasResElemTy()) { + assert(lBlock.hasSource()); + this->resElemTy = lBlock.getResElemTyRef(); + } else if (rBlock.hasResElemTy()) { + assert(rBlock.hasSource()); + this->resElemTy = rBlock.getResElemTyRef(); + } + + // Acctually `scalar` should be accumulated into `offset` and `stride` finally + // In addBlock, just pass `scalar` when: + // 1. both lhs and rhs have `scalar` + // 2. otherwise, both lhs and rhs are scalar type with rank 0 + // Except above, original `scalar` has been fused into `offset` under add. + if (lBlock.isScalar() && rBlock.isScalar()) { + auto addScalar = addOpFoldResult(lBlock.getScalarRef(), + rBlock.getScalarRef(), loc, rewriter); + this->scalar = addScalar; + } else if (lBlock.getRank() == 0) { + // When both lhs and rhs are scalar type with rank 0, just try passing + // potential `scalar` + this->scalar = + lBlock.isScalar() ? 
lBlock.getScalarRef() : rBlock.getScalarRef(); + } + + for (const auto &[lOffset, rOffset] : + llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { + this->offsets.push_back(addOpFoldResult(lOffset, rOffset, loc, rewriter)); + } + + for (const auto &[lStride, rStride] : + llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { + this->strides.push_back(addOpFoldResult(lStride, rStride, loc, rewriter)); + } + + // Both sizes are same implicitly under `add` + this->sizes = lBlock.getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); +} + +void BlockData::mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter) { + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + + assert(!(lBlock.hasSource() && rBlock.hasSource())); + + assert( + (lBlock.isScalar() ^ rBlock.isScalar()) && + "Currently only support one and only one scalar in function mulBlock()"); + + BlockData *lb = &lBlock; + BlockData *rb = &rBlock; + if (lb->isScalar()) { + std::swap(lb, rb); + } + + // In mulBlock, `scalar` will be accumulated into `offset` and `stride` + OpFoldResult rScalar = rb->getScalarRef(); + for (const auto &lOffset : lb->getOffsetsRef()) { + this->offsets.push_back(mulOpFoldResult(lOffset, rScalar, loc, rewriter)); + } + + for (const auto &lStride : lb->getStridesRef()) { + this->strides.push_back(mulOpFoldResult(lStride, rScalar, loc, rewriter)); + } + + this->sizes = lb->getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); +} + +void BlockData::divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, + ConversionPatternRewriter &rewriter) { + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + + assert(!(lBlock.hasSource() && rBlock.hasSource())); + + for (const auto &[lOffset, rOffset] : + llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { + this->offsets.push_back(divOpFoldResult(lOffset, rOffset, loc, rewriter)); + } + + for (const auto &[lStride, rStride] : + llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { + this->strides.push_back(divOpFoldResult(lStride, rStride, loc, rewriter)); + } + + this->sizes = lBlock.getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); +} + +memref::ReinterpretCastOp BlockData::createCastOp(ArrayRef resultShape, + const Location &loc, + OpBuilder &builder) const { + OpFoldResult resOffset = this->inferBlockOffset(loc, builder); + auto resultType = this->getResultMemrefType( + isa(resOffset) ? 
getConstantIntValue(resOffset).value() + : ShapedType::kDynamic, + resultShape); + + return builder.create( + loc, resultType, this->source, resOffset, this->sizes, this->strides); +} + +void BlockData::dump() const { + llvm::outs() << "[INFO][BEG] BlockData info\n"; + llvm::outs() << "offsets has " << offsets.size() << " items\n"; + int cnt = 0; + for (auto it = offsets.begin(); it != offsets.end(); ++it) { + llvm::outs() << "offsets[" << cnt++ << "] = " << *it << "\n"; + } + llvm::outs() << "sizes has " << sizes.size() << " items\n"; + cnt = 0; + for (auto it = sizes.begin(); it != sizes.end(); ++it) { + llvm::outs() << "sizes[" << cnt++ << "] = " << *it << "\n"; + } + llvm::outs() << "strides has " << strides.size() << " items\n"; + cnt = 0; + for (auto it = strides.begin(); it != strides.end(); ++it) { + llvm::outs() << "strides[" << cnt++ << "] = " << *it << "\n"; + } + llvm::outs() << "source = " << source << "\n"; + llvm::outs() << "scalar = " << scalar << "\n"; + llvm::outs() << "resElemTy = " << resElemTy << "\n"; + llvm::outs() << "memAccTy = " << memAccTy.toString() << "\n"; + llvm::outs() << "[INFO][END] BlockData info\n"; +} + +Value BlockDataParser::getScalarMemRef(Value ptr, Value memref, + const Location &loc, + ConversionPatternRewriter &rewriter) { + assert(isa(ptr.getType()) && "expect a scalar pointer"); + if (ptr.getDefiningOp()) { + if (auto castOp = memref.getDefiningOp()) { + return castOp.getResult(); + } else { + llvm_unreachable("pointer value is defined by an unexpected op"); + } + } + + assert(isa(ptr) && + "pointer should be produced by addptr or block argument"); + BlockData data; + data.setSource(memref); + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(1)); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + auto castOp = data.createCastOp(SmallVector(1, 1), loc, rewriter); + return castOp.getResult(); +} + +void BlockDataParser::parse( + Value operand, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + if (known.find(operand) != known.end()) { + return data = known.lookup(operand), void(); + } + + if (isa(operand.getType())) { + data.setScalar(getOpFoldResultOfLayoutInfo(operand, rewriter)); + return; + } + + // + if (isa(operand.getType())) { + // Just consider two state: ptr and ptr> + auto remappedPtr = rewriter.getRemappedValue(operand); + assert(remappedPtr); + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + parseAddPtr(addPtrOp, data, loc, rewriter, known); + } else if (auto bitcastOp = dyn_cast(op)) { + parseBitcast(bitcastOp, data, loc, rewriter, known); + } else if (auto makeTensorPtrOp = dyn_cast(op)) { + parseTensorPtr(makeTensorPtrOp, data, loc, rewriter, known); + } else if (auto advanceOp = dyn_cast(op)) { + // To support + // ptr_0 = tl.advance(ptr) + // ptr_1 = tl.advance(ptr_0) + parseTensorPtr(advanceOp, data, loc, rewriter, known); + } else if (auto intToPtrOp = dyn_cast(op)) { + data.setSource(remappedPtr); + } else { + LLVM_DEBUG({ llvm::dbgs() << operand << "\n"; }); + llvm_unreachable( + "Unexpected operand defining operation, a scalar " + "pointer can only be produced by AddPtrOp or direct block ptr"); + } + } else { + data.setSource(remappedPtr); + } + return; + } + + // not a scalar pointer + if (auto addOp = operand.getDefiningOp()) { + parseAdd(addOp, data, loc, rewriter, known); + } else if (auto mulOp = operand.getDefiningOp()) { + parseMul(mulOp, data, 
loc, rewriter, known); + } else if (auto addPtrOp = operand.getDefiningOp()) { + parseAddPtr(addPtrOp, data, loc, rewriter, known); + } else if (auto constOp = operand.getDefiningOp()) { + parseConstSplat(constOp, data, loc, rewriter, known); + } else if (auto broadcastOp = operand.getDefiningOp()) { + parseBroadcast(broadcastOp, data, loc, rewriter, known); + } else if (auto splatOp = operand.getDefiningOp()) { + parseSplat(splatOp, data, loc, rewriter, known); + } else if (auto expandDimsOp = + operand.getDefiningOp()) { + parseExpandDims(expandDimsOp, data, loc, rewriter, known); + } else if (auto remOp = operand.getDefiningOp()) { + parseRem(remOp, data, loc, rewriter, known); + } else if (auto bitcastOp = operand.getDefiningOp()) { + parseBitcast(bitcastOp, data, loc, rewriter, known); + } else if (auto extsiOp = operand.getDefiningOp()) { + parseExtSI(extsiOp, data, loc, rewriter, known); + } else if (auto divOp = operand.getDefiningOp()) { + parseDiv(divOp, data, loc, rewriter, known); + } else if (auto makeRangeOp = operand.getDefiningOp()) { + parseMakeRange(makeRangeOp, data, loc, rewriter, known); + } else if (auto reduceOp = operand.getDefiningOp()) { + parseReduce(reduceOp, data, loc, rewriter, known); + } else if (auto loadOp = operand.getDefiningOp()) { + parseIndirectLoad(loadOp, data, loc, rewriter, known); + } else if (auto castOp = operand.getDefiningOp()) { + parseIndirectLoad(castOp, data, loc, rewriter, known); + } else if (auto extractSliceOp = + operand.getDefiningOp()) { + parseExtractSlice(extractSliceOp, data, loc, rewriter, known); + } else if (auto forOp = operand.getDefiningOp()) { + parseIndirectLoad(forOp, data, loc, rewriter, known); + } else if (auto tensorCastOp = operand.getDefiningOp()) { + // Used for identity operation. 
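+    // The cast itself carries no offset/stride information, so the analysis
+    // simply looks through it and parses the cast's source operand instead.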
+ parse(tensorCastOp.getSource(), data, loc, rewriter, known); + } else { + operand.dump(); + llvm_unreachable("encountered AddPtrOp produced by unsupported operation"); + } +} + +void BlockDataParser::parseAdd( + arith::AddIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.addBlock(lBlock, rBlock, loc, rewriter); +} + +void BlockDataParser::parseMul( + arith::MulIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + + data.mulBlock(lBlock, rBlock, loc, rewriter); +} + +void BlockDataParser::parseDiv( + arith::DivSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.divBlock(lBlock, rBlock, loc, rewriter); +} + +// TODO : support modulos +void BlockDataParser::parseRem( + arith::RemSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(false && "Address expression with modulo is not supported yet, it " + "shall be analysis at linearize."); +} + +void BlockDataParser::parseMakeRange( + triton::MakeRangeOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + auto shape = dyn_cast(op.getType()).getShape(); + + auto start = op.getStart(); + auto end = op.getEnd(); + auto stride = (end >= start) && (end - start <= shape[0]); + assert(stride == 1 && + "make_range op should always return a tensor of stride 1"); + + data.getOffsetsRef().push_back(rewriter.getIndexAttr(start)); + data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(stride)); +} + +void BlockDataParser::parseExpandDims( + triton::ExpandDimsOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + parse(op.getSrcMutable().get(), data, loc, rewriter, known); + auto resShape = dyn_cast(op.getResult().getType()).getShape(); + auto axis = op.getAxis(); + + assert(resShape[axis] == 1 && + "The destiny shape of changed dimension should be 1"); + + data.getOffsetsRef().insert(data.getOffsetsRef().begin() + axis, + rewriter.getIndexAttr(0)); + data.getSizesRef().insert(data.getSizesRef().begin() + axis, + rewriter.getIndexAttr(1)); + data.getStridesRef().insert(data.getStridesRef().begin() + axis, + rewriter.getIndexAttr(0)); +} + +void BlockDataParser::parseExtractSlice( + tensor::ExtractSliceOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + const std::string scenarioMessages = + "PtsAnalysis supports indirectly block load in the " + "following scenario\n" + "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n" + "s = tl.extract_slice(indices, offsets= (i,), sizes= " + "(1,), strides= (1,)) # s is a tensor<1x$dtype>\n" + "D = tl.load(Cptr + s + Coffset) # s is used as the " + "scalar offset\n"; // tensor<2x$dtype> will be support + // 
soon + + auto extract_src = op->getOperand(0); + BlockData srcBlock; + parse(extract_src, srcBlock, loc, rewriter, known); + if (!srcBlock.hasSource()) { + llvm_unreachable(scenarioMessages.c_str()); + } + if (!isa(srcBlock.getSource().getDefiningOp())) { + llvm_unreachable(scenarioMessages.c_str()); + } + + auto extract_result = op->getResult(0); + auto shaped_ty = dyn_cast(extract_result.getType()); + auto shape = shaped_ty.getShape(); + if (shape.size() > 1 || shape[0] > 1) { + llvm_unreachable(scenarioMessages.c_str()); + } + auto castOp = rewriter.create( + loc, RankedTensorType::get(shape, rewriter.getIndexType()), + extract_result); + auto offset = castOp.getResult(); + if (data.isEmpty()) { + data.getOffsetsRef().push_back(offset); + data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + } else { + llvm_unreachable( + "parseExtractSlice with offset already setup not yet supported"); + } +} + +void BlockDataParser::parseBitcast( + triton::BitcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + parse(op.getSrc(), data, loc, rewriter, known); + + auto resType = op.getResult().getType(); + Type resElemPointeeTy = nullptr; + if (auto resShapedTy = dyn_cast(resType)) { + auto resElemTy = resShapedTy.getElementType(); + resElemPointeeTy = + dyn_cast(resElemTy).getPointeeType(); + } else { + auto srcPointeeType = + cast(op.getSrc().getType()).getPointeeType(); + auto resPointeeType = cast(resType).getPointeeType(); + + // Handling special case + // If Op is MetaUse or src is i1 block argument and dst is i8, + // it should be converted to UnrealizedConversionCast + if (op->hasAttr("MetaUse") || + (isa(op.getSrc()) && + srcPointeeType == rewriter.getIntegerType(1) && + resPointeeType == rewriter.getIntegerType(8))) { + resElemPointeeTy = resPointeeType; + } else { + auto remappedValue = rewriter.getRemappedValue(op); + data.setSource(remappedValue); + LLVM_DEBUG({ + llvm::dbgs() << "Remapping bitcastOp:\n"; + llvm::dbgs() << op << "\nto \n"; + llvm::dbgs() << remappedValue << "\n"; + }); + } + } + data.setResElemTy(resElemPointeeTy); +} + +void BlockDataParser::parseExtSI( + arith::ExtSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + parse(op.getIn(), data, loc, rewriter, known); +} + +void BlockDataParser::parseBroadcast( + triton::BroadcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + auto src = op.getSrcMutable().get(); + auto dst = op.getResult(); + assert(isa(src.getType()) && + "tt.broadcast's input should be a tensor"); + + auto srcShape = dyn_cast(src.getType()).getShape(); + auto dstShape = dyn_cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source shoule be equal to destnation"); + + parse(src, data, loc, rewriter, known); + + for (const auto &[idx, src_dst] : + llvm::enumerate(llvm::zip(srcShape, dstShape))) { + const auto &[srcAxis, dstAxis] = src_dst; + if (srcAxis == dstAxis) { + continue; + } + assert(srcAxis < dstAxis && + "srcShape of broadcastOp must be less than dstShape."); + data.getSizesRef()[idx] = rewriter.getIndexAttr(dstAxis); + } +} + +void BlockDataParser::parseSplat( + triton::SplatOp op, BlockData &data, const Location &loc, + 
ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + auto src = op.getSrc(); + auto dst = op.getResult(); + auto dstShape = dyn_cast(dst.getType()).getShape(); + + parse(src, data, loc, rewriter, known); + + if (isa(src.getType()) || + isa(src.getType())) { + if (!data.isEmpty()) { + data.getOffsetsRef().clear(); + data.getSizesRef().clear(); + data.getStridesRef().clear(); + } + for (auto dstAxis : dstShape) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + } else { + op->emitError("Block data Analysis: unsupported splat pattern"); + return; + } + if (data.isScalar()) { + data.getOffsetsRef()[0] = data.getScalarRef(); + } +} + +void BlockDataParser::parseConstSplat( + arith::ConstantOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + DenseElementsAttr denseAttr = dyn_cast(op.getValue()); + assert(denseAttr && denseAttr.isSplat() && + isa(denseAttr.getElementType())); + + auto innerVal = denseAttr.getValues()[0].getValue(); + auto innerValIndexAttr = rewriter.getIndexAttr(innerVal.getSExtValue()); + + // for mul state + data.setScalar(innerValIndexAttr); + + auto resType = dyn_cast(op.getResult().getType()); + size_t loopLimit = resType.getShape().size(); + for (auto i = 0; i < loopLimit; i++) { + // Add original dense val to first dim offset for add state + if (i == 0) { + data.getOffsetsRef().push_back(innerValIndexAttr); + } else { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + } + data.getSizesRef().push_back(rewriter.getIndexAttr(resType.getShape()[i])); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } +} + +template +std::enable_if_t || + std::is_same_v> +BlockDataParser::parseTensorPtr( + T op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + Value remappedValue = rewriter.getRemappedValue(op); + if (auto castOp = remappedValue.getDefiningOp()) { + parseReinterpretCast(castOp, data, loc, rewriter, known); + } else { + llvm_unreachable("the value should be mapped to memref.reinterpret_cast"); + } +} + +void BlockDataParser::parseAddPtr( + triton::AddPtrOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + BlockData ptrBlock, offsetBlock; + parse(op.getPtr(), ptrBlock, op.getLoc(), rewriter, known); + parse(op.getOffset(), offsetBlock, op.getLoc(), rewriter, known); + + assert(ptrBlock.hasSource() && + "Ptr field should provide source/base pointer"); + // offset has source means offset is from tl.load and other ops(TODO) + if (offsetBlock.hasSource()) { + ptrBlock.setMemAccTy(offsetBlock.getMemAccType()); + offsetBlock.removeSource(); + } + + // handle for loop & scalar + if (ptrBlock.getRank() == 1 && offsetBlock.getRank() == 0) { + offsetBlock.getSizesRef().push_back(rewriter.getIndexAttr(1)); + offsetBlock.getOffsetsRef().push_back(offsetBlock.getScalarRef()); + offsetBlock.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + + assert(ptrBlock.getRank() == offsetBlock.getRank() && + "ptr and offset should have same rank"); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr][BEG] =========================\n"; + os << "[parseAddPtr] op is " << op << "\n"; + 
for (int i = 0; i < ptrBlock.getRank(); i++) { + os << "ptrBlock.getOffsetsRef()[" << i + << "] = " << ptrBlock.getOffsetsRef()[i] << "\n"; + os << "ptrBlock.getSizesRef()[" << i + << "] = " << ptrBlock.getSizesRef()[i] << "\n"; + os << "ptrBlock.getStridesRef()[" << i + << "] = " << ptrBlock.getStridesRef()[i] << "\n"; + os << "offsetBlock.getOffsetsRef()[" << i + << "] = " << offsetBlock.getOffsetsRef()[i] << "\n"; + os << "offsetBlock.getSizesRef()[" << i + << "] = " << offsetBlock.getSizesRef()[i] << "\n"; + os << "offsetBlock.getStridesRef()[" << i + << "] = " << offsetBlock.getStridesRef()[i] << "\n"; + } + os << "[parseAddPtr][END] -------------------------\n"; + }); + data.addBlock(ptrBlock, offsetBlock, op.getLoc(), rewriter); +} + +void BlockDataParser::parseReinterpretCast( + memref::ReinterpretCastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + assert(data.isEmpty()); + + data.setOffsets(op.getMixedOffsets()); + data.setSizes(op.getMixedSizes()); + data.setStrides(op.getMixedStrides()); + data.setSource(op.getSource()); + + // In memref::ReinterpretCastOp, offset means the total of collapsing multiple + // dimensions, which corresponds to first dim offset in block data. + // Here populate the rest of the dimensions with zeroes. + assert(data.getOffsetsRef().size() == 1); + size_t loopLimit = data.getSizesRef().size(); + for (size_t i = 1; i < loopLimit; i++) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + } +} + +void BlockDataParser::parseReduce( + triton::ReduceOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + + const std::string scenarioMessages = + "PtsAnalysis supports indirectly block load in the following scenario\n" + "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n" + "s = tl.min(B) # s is a scalar\n" + "D = tl.load(Cptr + s + Coffset) # s is used as the scalar offset\n"; + + auto reduce_src = op->getOperand(0); + BlockData srcBlock; + parse(reduce_src, srcBlock, loc, rewriter, known); + if (!srcBlock.hasSource()) { + llvm_unreachable(scenarioMessages.c_str()); + } + if (!isa(srcBlock.getSource().getDefiningOp())) { + llvm_unreachable(scenarioMessages.c_str()); + } + + auto reduce_result = op->getResult(0); + auto shaped_ty = dyn_cast(reduce_result.getType()); + auto shape = shaped_ty.getShape(); + auto ops = llvm::map_to_vector(op.getBody()->without_terminator(), + [](Operation &op) { return &op; }); + // Support only the case: scalar = tl.load(1D tensor) + if (shape.size() != 1 || op.getAxis() != 0 || ops.size() != 1 || + !isa(ops.front())) { + llvm_unreachable(scenarioMessages.c_str()); + } + + auto castOp = rewriter.create( + loc, RankedTensorType::get(shape, rewriter.getIndexType()), + reduce_result); + auto offset = castOp.getResult(); + if (data.isEmpty()) { + data.getOffsetsRef().push_back(offset); + data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + } else { + llvm_unreachable("parseReduce with offset already setup not yet supported"); + } +} + +template +void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) { + // FIXME: assume single result of operation + auto opRes = op->getResult(0); + auto opResTy = opRes.getType(); + std::vector resShape; + if (auto shapedResTy = dyn_cast(opResTy)) { + // For now, we consider this is UnstrucMemAcc 
because we have no other info. + // Visiting other ops may change the type due to more info. + data.setMemAccVal(MemAccVal::UnstrucMemAcc); + resShape = shapedResTy.getShape().vec(); + } else { + // scalar load means this is used as offset. It is StrucMemAcc. + data.setMemAccVal(MemAccVal::StrucMemAcc); + resShape.push_back(1); + } + for (auto &s : resShape) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(s)); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + } + // set the source in BlockData so that we know an indirect-load op exists in + // the chain. + data.setSource(opRes); +} + +void BlockDataParser::rewriteAddPtr( + triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) { + auto insertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + BlockData data; + parseAddPtr(op, data, op.getLoc(), rewriter, known); + + if (auto src = data.getSource(); + data.getMemAccTypeRef().isUnstructured() && + !(src && isa_and_nonnull(src.getDefiningOp()))) { + // TODO: Based on more info, try to create a performant IR + rewriteAddPtrToUnstrucMemAcc(op, adaptor, rewriter, data); + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); + return; + } + + if (data.getSizesRef().size() == 0) { + data.getSizesRef().push_back(rewriter.getIndexAttr(1)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + data.getOffsetsRef().push_back(data.getScalarRef()); + } + + ArrayRef resultShape; + // shape {1,} is stub for single ptr + SmallVector stubScalarTypeShape(1, 1); + if (auto shapedType = dyn_cast(op.getResult().getType())) { + resultShape = shapedType.getShape(); + } else { + assert(data.getRank() == 1); + resultShape = stubScalarTypeShape; + } + + known[op.getResult()] = data; + + // If there are dimensions with size 1 and stride 0, replace 0 stride with the + // product of sizes of all lower dimensions. This avoids creating memref with + // zero stride. + // And here store the unmodified state into known ptrs, since any following + // pointer arithmetic operations should still use the original 0 stride. 
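+  // For example, sizes [1, 128] with strides [0, 1] become strides [128, 1],
+  // while sizes [128, 1] with strides [1, 0] become strides [1, 1].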
+ auto inferedSize = 1; + for (int i = data.getSizesRef().size() - 1; i >= 0; i--) { + auto strideConst = getConstantIntValue(data.getStridesRef()[i]); + auto sizeConst = getConstantIntValue(data.getSizesRef()[i]); + assert(sizeConst.has_value()); + if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) { + data.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize); + } + inferedSize *= sizeConst.value(); + } + + if (auto intToPtrOp = + dyn_cast(data.getSourceRef().getDefiningOp())) { + auto rtype = cast(intToPtrOp.getResult().getType()); + auto memrefType = + MemRefType::get({ShapedType::kDynamic}, rtype.getPointeeType()); + auto hivmPointCastOp = rewriter.create( + intToPtrOp.getLoc(), memrefType, ValueRange{intToPtrOp.getSrc()}); + data.setSource(hivmPointCastOp.getResult()); + } + + if (data.hasResElemTy()) { + // Handle bitcast scenario + auto memrefType = dyn_cast(data.getSourceRef().getType()) + .cloneWith(std::nullopt, data.getResElemTyRef()); + UnrealizedConversionCastOp castOp = + rewriter.create( + op.getLoc(), memrefType, data.getSourceRef()); + data.setSource(castOp.getOutputs()[0]); + } + + // ToDo: need to handle module scenario + + memref::ReinterpretCastOp castOp = + data.createCastOp(resultShape, op.getLoc(), rewriter); + Value src = castOp.getResult(); + LLVM_DEBUG({ + llvm::dbgs() << "cast MemRefType:\n"; + castOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + rewriter.replaceOp(op, src); + rewriter.restoreInsertionPoint(insertPoint); +} + +OpFoldResult accumulatePotentialOffsetOnBase( + triton::MakeTensorPtrOp op, Value base, OpFoldResult offset, + ConversionPatternRewriter &rewriter) { + if (auto baseRecast = base.getDefiningOp()) { + assert(isa(op.getBase().getDefiningOp()) && + "base of MakeTensorPtrOp only comes from native ptr or AddPtrOp"); + + return addOpFoldResult(offset, baseRecast.getConstifiedMixedOffset(), + op.getLoc(), rewriter); + } + + return offset; +} + +// Design for load/store boundary_check. +memref::ReinterpretCastOp +createRedundantOp(triton::MakeTensorPtrOp op, + ConversionPatternRewriter &rewriter, + BlockData &data) { + auto loc = op.getLoc(); + // to do boundary_check in tt.load, we need to keep the parent tensor's + // shape info in the IR. + // use the parent tensor's shape to create a cast + auto resultSizes = data.getSizes(); + auto resultOffsets = data.getOffsets(); + data.getSizesRef().clear(); + data.getOffsetsRef().clear(); + data.getSizesRef() = + std::move(llvm::map_to_vector(op.getShape(), [&](Value v) { + return getOpFoldResultOfLayoutInfo(v, rewriter); + })); + + // This redundant ReinterpretCastOp is to describe full tensor_ptr, so each + // dim offset from base is initialized as zero. 
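+  // For example, for a 2-D block pointer this produces a reinterpret_cast whose
+  // sizes equal the parent tensor's shape and whose offsets are all zero, apart
+  // from any offset already folded into the base recast (accumulated below).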
+ SmallVector curOffsets(op.getOffsets().size(), + rewriter.getIndexAttr(0)); + // Just accumulate base potential offset + curOffsets.front() = accumulatePotentialOffsetOnBase( + op, rewriter.getRemappedValue(op.getBase()), curOffsets.front(), + rewriter); + + for (auto offset : curOffsets) { + data.getOffsetsRef().push_back(offset); + } + + SmallVector staticShapes; + SmallVector dynamicShapes; + dispatchIndexOpFoldResults(data.getSizesRef(), dynamicShapes, staticShapes); + auto castOp = data.createCastOp(staticShapes, loc, rewriter); + // restore sizes and offsets + data.getSizesRef().clear(); + for (auto &s : resultSizes) { + data.getSizesRef().push_back(s); + } + data.getOffsetsRef().clear(); + for (auto &offset : resultOffsets) { + data.getOffsetsRef().push_back(offset); + } + return castOp; +} + +void BlockDataParser::rewriteMakeTensorPtrOp( + triton::MakeTensorPtrOp op, Value base, + ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) { + Location loc = op.getLoc(); + BlockData data; + + auto orderSize = op.getOrder().size(); + if (orderSize > 1) { + // Declaration of llvm::ArrayRef::slice(n, m) + // - Chop off the first N elements of the array, and keep M elements + // in the array. + // Take care that 'm' means chunk length + for (auto [first, second] : + llvm::zip(op.getOrder().slice(0, orderSize - 1), + op.getOrder().slice(1, orderSize - 1))) { + assert(first == second + 1 && "Currently only support default order on block pointers"); + } + } + + // Handle base is defined by tt.bitcast + BlockDataParser::parse(op.getBase(), data, loc, rewriter, known); + if (data.hasResElemTy()) { + auto memrefType = dyn_cast(data.getSourceRef().getType()) + .cloneWith(std::nullopt, data.getResElemTyRef()); + UnrealizedConversionCastOp castOp = + rewriter.create(loc, memrefType, + data.getSourceRef()); + data.setSource(castOp.getOutputs()[0]); + } else { + data.setSource(rewriter.getRemappedValue(op.getBase())); + } + + data.getOffsetsRef() = + std::move(llvm::map_to_vector(op.getOffsets(), [&](Value v) { + return getOpFoldResultOfLayoutInfo(v, rewriter); + })); + data.getStridesRef() = + std::move(llvm::map_to_vector(op.getStrides(), [&](Value v) { + return getOpFoldResultOfLayoutInfo(v, rewriter); + })); + + SmallVector newOffsets; + for (auto [offset, stride] : + llvm::zip(data.getOffsetsRef(), data.getStridesRef())) + newOffsets.push_back(mulOpFoldResult(offset, stride, loc, rewriter)); + + // 1. Consider that current base ptr may comes from `triton::AddPtrOp`, + // which have been converted to `memref::ReinterpretCastOp` with 1D + // shape([1,]) by `AddPtrConverter`. + // 2. While here would also convert `triton::MakeTensorPtrOp` to + // `memref::ReinterpretCastOp`, it will create use-def on double recast + // which means offset&size&stride info of first one will be dropped in terms + // of memref recast op fold specification. + // + // Conclusion with above two: + // Base of MakeTensorPtrOp has been seen as origin base, so it should + // reserve offset of first recast if it exists. 
+ // Here extract the offset of first recast and add it to highest dimension + newOffsets.front() = accumulatePotentialOffsetOnBase( + op, base, newOffsets.front(), rewriter); + + data.getOffsetsRef().clear(); + + for (auto offset : newOffsets) { + data.getOffsetsRef().push_back(offset); + } + + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + data.getSizesRef().clear(); + for (auto dim_size : resultShape) { + data.getSizesRef().push_back( + IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size)); + } + } else { + // scalar pointer, should produce a one dimensional memref + SmallVector scalarShape(1, 1); + resultShape = scalarShape; + assert(data.getRank() == 1); + } + + known[op.getResult()] = data; + + // special handling for davinci + // create redundant reinterpret_cast op for record shape info + auto redundantOp = createRedundantOp(op, rewriter, data); + redundantOp->setAttr("tensor_ptr_full_shape", rewriter.getUnitAttr()); + + // create reinterpret_cast op for the target block + data.setSource(redundantOp.getResult()); + auto castOp = data.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, castOp.getResult()); + + if (nd2nzFlag) { + auto basePtr = castOp.getResult(); + int original_rank = op.getShape().size() + 1; + std::string shapeStr; + + auto baseMemrefType = mlir::dyn_cast(basePtr.getType()); + assert(baseMemrefType && "basePtr is not a memref type"); + auto shape = baseMemrefType.getShape(); + + if (auto memrefType = mlir::dyn_cast(basePtr.getType())) { + for (auto dim : memrefType.getShape()) { + shapeStr += llvm::formatv("_{0}", dim); + } + } + std::string elemTypeName; + Type elemType = baseMemrefType.getElementType(); + if (auto intType = mlir::dyn_cast(elemType)) { + elemTypeName = llvm::formatv("i{0}", intType.getWidth()); + } else if (auto floatType = mlir::dyn_cast(elemType)) { + std::string floatTypeName; + llvm::raw_string_ostream os(floatTypeName); + floatType.print(os); + os.flush(); + elemTypeName = floatTypeName; + } else { + std::string typeName; + llvm::raw_string_ostream os(typeName); + elemType.print(os); + os.flush(); + elemTypeName = typeName; + } + + std::string memrefTypeStr; + llvm::raw_string_ostream os(memrefTypeStr); + baseMemrefType.print(os); + os.flush(); + + std::string laydbgsuffix; + for (char c : memrefTypeStr) { + if ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || c == '_' || c == ',' || c == '[' || + c == ']') { + laydbgsuffix += c; + } + } + auto funcName = rewriter.getStringAttr( + llvm::formatv("__hmf_original_shape{0}d{1}_{2}_{3}", original_rank, + shapeStr, elemTypeName, laydbgsuffix)); + MemRefType targetMemrefType = MemRefType::get( + baseMemrefType.getShape(), baseMemrefType.getElementType(), + baseMemrefType.getLayout()); + const int vectorSize = 4; + SmallVector srcElemTys; + for (auto sz : op.getShape()) { + srcElemTys.push_back(sz.getType()); + } + srcElemTys.push_back(targetMemrefType); + Type dstElemTy = rewriter.getNoneType(); + FunctionType hintFuncType = + FunctionType::get(rewriter.getContext(), srcElemTys, {dstElemTy}); + + auto mod = SymbolTable::getNearestSymbolTable(op); + auto extFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(mod, funcName)); + SmallVector args; + for (auto sz : op.getShape()) { + args.push_back(sz); + } + args.push_back(basePtr); + if (!extFunc) { + OpBuilder::InsertionGuard guard(rewriter); + 
rewriter.setInsertionPointToStart(&mod->getRegion(0).front()); + extFunc = rewriter.create(rewriter.getUnknownLoc(), + funcName, hintFuncType); + extFunc.setPrivate(); + extFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), + UnitAttr::get(rewriter.getContext())); + rewriter.setInsertionPoint(op); + } + rewriter.create(loc, funcName, dstElemTy, args); + } +} + +void BlockDataParser::rewriteAdvanceOp( + triton::AdvanceOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) { + OpBuilder::InsertionGuard insertionGuard(rewriter); + rewriter.setInsertionPoint(op); + auto loc = op.getLoc(); + + BlockData blockData; + parse(op.getOperand(0), blockData, loc, rewriter, known); + + // region [BUGFIX] Add the code block below following the same logic as + // 'BlockDataParser::rewriteAddPtr' function. + known[op.getResult()] = blockData; + auto inferedSize = 1; + for (int i = blockData.getSizesRef().size() - 1; i >= 0; i--) { + auto strideConst = getConstantIntValue(blockData.getStridesRef()[i]); + auto sizeConst = getConstantIntValue(blockData.getSizesRef()[i]); + assert(sizeConst.has_value()); + if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) { + blockData.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize); + } + inferedSize *= sizeConst.value(); + } + // endregion + + SmallVector incrementOffsets = + llvm::map_to_vector(op.getOffsets(), [&](Value offset) { + return getOpFoldResultOfLayoutInfo(offset, rewriter); + }); + + SmallVector newOffsets; + for (const auto [increment, originalOffset, stride] : + llvm::zip(incrementOffsets, blockData.getOffsetsRef(), + blockData.getStridesRef())) { + auto curDimOffset = + addOpFoldResult(mulOpFoldResult(increment, stride, loc, rewriter), + originalOffset, loc, rewriter); + + newOffsets.push_back(curDimOffset); + } + + blockData.getOffsetsRef().clear(); + + for (auto offset : newOffsets) + blockData.getOffsetsRef().push_back(offset); + + SmallVector scalarShape(1, 1); // Stub shape + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); + } else { + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(blockData.getRank() == 1); + } + + auto newOp = blockData.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, newOp.getResult()); + + known[newOp.getResult()] = blockData; +} + +void BlockDataParser::rewriteYieldOp( + scf::YieldOp op, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseSet &blockArgIdxSet, ArrayRef iterArgIdxMap, + const llvm::SmallDenseMap &known) { + // Any inserted instruction should be before this yield + OpBuilder::InsertionGuard insertionGuard{rewriter}; + rewriter.setInsertionPoint(op); + + auto adaptor = scf::YieldOp::Adaptor(op); + + SmallVector initArgState; + SmallVector operands; + + operands.reserve(op->getNumOperands()); + for (const auto &[oper, newIterArgIdx]: llvm::zip_equal(adaptor.getOperands(), iterArgIdxMap)) { + if (newIterArgIdx != -1) + operands.push_back(oper); + } + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // BlockData for those values. 
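+  // In other words (a sketch of the layout): for every init arg that
+  // rewriteLoopOp extended with extra iter_args, the loop below recovers its
+  // BlockData, and the following loop appends one index Value per offset
+  // dimension and one per stride dimension to the yield operands, in the
+  // same order.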
+ for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { + if (auto mappedV = rewriter.getRemappedValue(v)) { + // If this value is a tensor of pointers produced by AddPtrOp, + // we should have already converted to a ReinterpretCastOp without + // layout information for the normal cases + if (v.getDefiningOp() || + v.getDefiningOp() || + v.getDefiningOp()) { + if (auto castOp = mappedV.getDefiningOp()) { + v = castOp; + } else { + llvm_unreachable("mapped value defined by an unexpected op"); + } + } else { + // If this value is not a tensor of pointers, we will use the + // mapped value, and rely on the conversion will happen later + // automatically when we legalize loop body. + + // TODO: + // The scenario where a value is a tensor of pointers but not + // produced by AddPtrOp is not supported + if (isa(mappedV.getType()) && + isa( + dyn_cast(mappedV.getType()).getElementType())) + llvm_unreachable("unsupported scenario where a value is a tensor of " + "pointers but not produced by AddPtrOp"); + v = mappedV; + } + } + + if (blockArgIdxSet.find(i) == blockArgIdxSet.end()) + continue; + + auto reintCastOp = v.getDefiningOp(); + assert( + reintCastOp || + (isa(v.getType()) && + isa(dyn_cast(v.getType()).getElementType()))); + + BlockData state; + if (reintCastOp) { + parseReinterpretCast(reintCastOp, state, op.getLoc(), rewriter, known); + } else { + parse(v, state, op.getLoc(), rewriter, known); + } + initArgState.push_back(state); + } + + // For each of the BlockData recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : initArgState) { + for (auto offset : state.getOffsetsRef()) { + // offsets can be IntAttr zeroes, since reinterpret_cast collapses + // them for the input memref, and the for loop may not update + // offsets other than offsets[0]. Create constants Values for those + // zeroes. 
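+      // e.g. an IntegerAttr 0 recorded by the analysis is materialized below
+      // as `%c0 = arith.constant 0 : index` so it can be yielded as a Value.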
+      if (isa(offset)) {
+        auto constOffset = offset.get();
+        assert(isa(constOffset) &&
+               dyn_cast(constOffset).getInt() == 0 &&
+               "attribute offsets should be zeroes");
+        auto constOp = rewriter.create(
+            op.getLoc(), rewriter.getIndexAttr(0));
+        operands.push_back(constOp.getResult());
+      } else {
+        operands.push_back(offset.get());
+      }
+    }
+
+    for (OpFoldResult stride : state.getStridesRef()) {
+      if (isa(stride)) {
+        auto constStride = stride.get();
+        assert(isa(constStride) &&
+               dyn_cast(constStride).getInt() == 1 &&
+               "attribute strides should be ones");
+        auto constOp = rewriter.create(
+            op.getLoc(), rewriter.getIndexAttr(1));
+        operands.push_back(constOp.getResult());
+      } else {
+        operands.push_back(stride.get());
+      }
+    }
+  }
+
+  // Yield is a terminator op that must be at the end of the function
+  rewriter.setInsertionPointAfter(op);
+  auto newOp = rewriter.replaceOpWithNewOp(op, operands);
+  assert(op->getNumResults() == 0);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "new yield:";
+    newOp.getOperation()->print(llvm::dbgs(),
+                                OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+}
+
+// Utility function for rewriteLoopOp that checks whether the given
+// regionIterArg is used by a MemAccOp.
+bool isUsedByMemAccOp(Value v, int depth = 0) {
+  for (auto &use : v.getUses()) {
+    auto *user = use.getOwner();
+    if (user->hasAttr(ConverterUtils::discreteAttrName))
+      continue;
+    if (isa(user))
+      return true;
+    if (auto loopOp = dyn_cast(user);
+        loopOp && !loopOp->hasAttr("ExtractedLoadOrStore")) {
+      if (isUsedByMemAccOp(loopOp.getTiedLoopRegionIterArg(&use), depth + 1))
+        return true;
+    } else if (auto yieldOp = dyn_cast(user)) {
+      if (depth && isUsedByMemAccOp(yieldOp->getParentOp()->getResult(use.getOperandNumber()), depth - 1))
+        return true;
+    }
+    for (auto res : user->getResults()) {
+      if (isUsedByMemAccOp(res, depth))
+        return true;
+    }
+  }
+  return false;
+}
+
+// Utility function for rewriteLoopOp that checks whether the given
+// regionIterArg is used as a mask.
+// Assuming the unstructured case has been handled in a previous pass, the
+// given value is a 1D tensor with unit stride.
+bool isUsedforMask(Value v, int depth = 0) {
+  if (auto tensorType = dyn_cast(v.getType());
+      !(tensorType && tensorType.getRank() == 1))
+    return false;
+  for (auto &use : v.getUses()) {
+    auto *user = use.getOwner();
+    if ((isa(user) && use.getOperandNumber() == 1) ||
+        (isa(user) && use.getOperandNumber() == 2))
+      return true;
+    if (auto loopOp = dyn_cast(user);
+        loopOp && !loopOp->hasAttr("ExtractedLoadOrStore")) {
+      if (isUsedforMask(loopOp.getTiedLoopRegionIterArg(&use), depth + 1))
+        return true;
+    } else if (auto yieldOp = dyn_cast(user)) {
+      if (depth && isUsedforMask(yieldOp->getParentOp()->getResult(use.getOperandNumber()), depth - 1))
+        return true;
+    }
+    for (auto res : user->getResults()) {
+      if (isUsedforMask(res, depth))
+        return true;
+    }
+  }
+  return false;
+}
+
+// Utility function for rewriteLoopOp that creates a value from BlockData.
+// Assumes the data is structured and comes from a regionIterArg of a
+// LoopLikeOpInterface.
+//
+// For example,
+//
+// %7 = scf.for %arg2 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg3 = %4) -> (tensor<128xi32>) : i32 {
+//   %8 = tt.addptr %5, %arg3 : tensor<128x!tt.ptr>, tensor<128xi32>
+//   ...
+// } +// +// is converted to +// +// %7 = scf.for %arg2 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg3 = %4, %arg4 = %5, %arg5 = %6) -> (tensor<128xi32>) : i32 { +// %scalarOffset = arith.index_cast %arg4 : index to i32 +// %scalarStride = arith.index_cast %arg5 : index to i32 +// ... +// %newRes = arith.addi %offset, %stride : tensor<128xi32> +// %8 = tt.addptr %5, %newRes : tensor<128x!tt.ptr>, tensor<128xi32> +// } +Value createFromData(RankedTensorType resType, const BlockData &data, const Location &loc, OpBuilder &builder, bool isMaskIterArg) { + auto resShape = resType.getShape(); + Value newRes = nullptr; + for (size_t i = 0; i < resShape.size(); i++) { + auto axisType = RankedTensorType::get({resShape[i]}, resType.getElementType()); + auto axisI32Type = RankedTensorType::get({resShape[i]}, builder.getIntegerType(32)); + Value axisValue = builder.create(loc, axisI32Type, 0, resShape[i]); + if (axisType != axisI32Type) { + axisValue = builder.create(loc, axisType, axisValue); + } + Value offset = cast(data.getOffset(i)); + Value offsetValue = builder.create(loc, resType.getElementType(), offset); + offsetValue = builder.create(loc, axisType, offsetValue); + Value stride = cast(data.getStride(i)); + if (!isMaskIterArg) { + Value strideValue = builder.create(loc, resType.getElementType(), stride); + strideValue = builder.create(loc, axisType, strideValue); + axisValue = builder.create(loc, axisValue, strideValue); + } + axisValue = builder.create(loc, axisValue, offsetValue); + + for (size_t j = 0; j < resShape.size(); j++) { + if (i != j) + axisValue = builder.create(loc, axisValue, j); + } + axisValue = builder.create(loc, resType, axisValue); + if (newRes) { + newRes = builder.create(loc, newRes, axisValue); + } else { + newRes = axisValue; + } + } + return newRes; +} + +void BlockDataParser::rewriteLoopOp( + LoopLikeOpInterface op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) { + SmallVector newInitArgs; + SmallVector iterArgIdxMap; + SmallVector maskIterArgs; + int64_t argCnt = 0; + + SmallVector, 5> initArgIndexIfBlockData; + SmallVector, 5> knownPtrsTmp; + llvm::SmallDenseSet blockArgIdxSet; + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInits())) { + auto mappedV = rewriter.getRemappedValue(arg); + memref::ReinterpretCastOp reintCastOp; + maskIterArgs.push_back(false); + + // If this init arg is supposed to be remapped, use the remapped + // value instead. + // In addition, if this init arg is a memref created by a reinterpret_cast + // or a tensor of index, there is a chance that it will be used in addptr. + // Create BlockData for each such init arg. + if (mappedV) { + // TODO: + // Passing a block argument pointer directly into a for loop not + // supported. 
+ assert(!(isa(mappedV) && + isa(mappedV.getType())) && + "cannot take pointer block argument as init arg for for loop"); + if (auto reinterpretCastOp = mappedV.getDefiningOp()) { + // Record memref::ReinterpretCastOp + reintCastOp = reinterpretCastOp; + newInitArgs.push_back(mappedV); + iterArgIdxMap.push_back(argCnt++); + } else if (auto defOp = op.getYieldedValues()[i].getDefiningOp(); + (defOp && defOp->hasAttr("MetaUse"))) { + // When argument is MetaUse in the loop, + // It is removed in iter_args + newInitArgs.push_back(nullptr); + iterArgIdxMap.push_back(-1); + } else { + newInitArgs.push_back(mappedV); + iterArgIdxMap.push_back(argCnt++); + } + } else { + newInitArgs.push_back(arg); + iterArgIdxMap.push_back(argCnt++); + } + + auto indexTensor = + isa(arg.getType()) && + isa(cast(arg.getType()).getElementType()) && + isUsedByMemAccOp(op.getRegionIterArgs()[i]); + + // Handle memref::ReinterpretCastOp and tensor specially + if (!reintCastOp && !indexTensor) + continue; + + BlockData data; + if (reintCastOp) { + parseReinterpretCast(reintCastOp, data, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } else { + parse(arg, data, op.getLoc(), rewriter, + llvm::SmallDenseMap(0)); + } + + maskIterArgs[i] = indexTensor && isUsedforMask(op.getRegionIterArgs()[i]); + + // Record the BlockData for later processing + initArgIndexIfBlockData.push_back(std::make_pair(i, data)); + } + + // Set insertion point to be before the for loop for new variables passed + // into the new loop. + auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + // For each of the BlockData recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto [i, data] : initArgIndexIfBlockData) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list, which is prepared for calculate layout info with + // loop interation index + for (auto &dataOffset : data.getOffsetsRef()) { + if (isa(dataOffset)) { + auto constDataOffset = dataOffset.get(); + assert(isa(constDataOffset)); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr( + dyn_cast(constDataOffset).getInt())); + newInitArgs.push_back(constOp.getResult()); + dataOffset = constOp.getResult(); + } else { + assert(isa(dataOffset.get().getType())); + newInitArgs.push_back(dataOffset.get()); + } + } + + for (auto &dataStride : data.getStridesRef()) { + if (isa(dataStride)) { + auto constDataStride = dataStride.get(); + assert(isa(constDataStride)); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr( + dyn_cast(constDataStride).getInt())); + newInitArgs.push_back(constOp.getResult()); + dataStride = constOp.getResult(); + } else { + assert(isa(dataStride.get().getType())); + newInitArgs.push_back(dataStride.get()); + } + } + + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the blockdata we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, data)); + blockArgIdxSet.insert(i); + + // If the original init arg is a memref produced by reinterpret_cast, + // create a new memref using new strides and offsets created above. 
+ // This produces a canonicalized memref, which will match what the + // for loop generates if it modifies the memref. E.g., original + // reinterpret_cast can produce a memref with const stride: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + + // s0 + d1 + // * s1)>> + // The new reinterpret_cast will always have dynamic stride and + // offset: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + // + s0 + d1 * s2)>> + if (newInitArgs[i] && newInitArgs[i].getDefiningOp()) { + SmallVector resultShape; + for (auto size : data.getSizesRef()) { + auto constSize = getConstantIntValue(size); + assert(constSize && "expected constant size"); + resultShape.push_back(constSize.value()); + } + + // In current block data layout info, strides and offsets must be dynamic + // value + auto castOp = data.createCastOp(resultShape, op.getLoc(), rewriter); + if (resultShape.size() > 1) { + auto originalOffset = dyn_cast(data.getOffsetsRef()[0]); + for (auto &offsets : newInitArgs) { + if (offsets == originalOffset) { + offsets = castOp.getOffsets()[0]; + break; + } + } + data.getOffsetsRef()[0] = castOp.getOffsets()[0]; + } + + LLVM_DEBUG({ + llvm::dbgs() << "new reinterpret_cast with dynamic sizes " + "and offsets:"; + castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + newInitArgs[i] = castOp.getResult(); + } + } + + rewriter.restoreInsertionPoint(origIp); + IRMapping mapping; + + // Create a new LoopOp that uses updated init args and same loop body + LoopLikeOpInterface newOp; + auto newInits = to_vector(make_filter_range(newInitArgs, [](Value v) { return v != nullptr; })); + auto commonBodyBuilder = [&](OpBuilder &b, Location loc, ValueRange newRegionArgs, Region ®ion, Block::BlockArgListType regionArgs) { + auto newArgIter = newRegionArgs.begin(); + for (const auto &[initArg, regionArg, newInitArg]: llvm::zip(op.getInits(), regionArgs, newInitArgs)) { + if (newInitArg) { + mapping.map(initArg, newInitArg); + mapping.map(regionArg, *newArgIter); + ++newArgIter; + } + } + + // Convert the book-keeping data structure to use the correct key and value. 
+ // Key is converted from init arg index to newly created block arg, and + // Value's BlockData fields are converted from init arg to newly created block + // arg + for (auto [i, data] : knownPtrsTmp) { + for (auto &offset: data.getOffsetsRef()) { + offset = *newArgIter; + ++newArgIter; + } + + for (auto &stride: data.getStridesRef()) { + stride = *newArgIter; + ++newArgIter; + } + + auto regionArg = regionArgs[i]; + auto key = mapping.lookupOrNull(regionArg); + if (!key) { + // Create MetaUse regionArg from computed offset and stride data + key = createFromData(cast(regionArg.getType()), data, op.getLoc(), rewriter, maskIterArgs[i]); + mapping.map(regionArg, key); + } + known.insert(std::make_pair(key, data)); + } + + for (auto &bodyOp : region.getOps()) + b.clone(bodyOp, mapping); + }; + + if (auto forOp = dyn_cast(op.getOperation())) { + newOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInits, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + mapping.map(forOp.getInductionVar(), iv); + commonBodyBuilder(b, loc, args, forOp.getRegion(), op.getRegionIterArgs()); + }); + } else if (auto whileOp = dyn_cast(op.getOperation())) { + auto resultTypes = map_to_vector(newInits, [](auto v) { return v.getType(); }); + newOp = rewriter.create( + whileOp.getLoc(), resultTypes, newInits, + [&](OpBuilder &b, Location loc, ValueRange args) { + commonBodyBuilder(b, loc, args, whileOp.getBefore(), whileOp.getBeforeArguments()); + }, + [&](OpBuilder &b, Location loc, ValueRange args) { + commonBodyBuilder(b, loc, args, whileOp.getAfter(), whileOp.getAfterArguments()); + }); + } + + // Replace only the results that correspond to the original scf.for + auto newResultIter = newOp->result_begin(); + rewriter.setInsertionPointAfter(newOp); + for (const auto &[res, regionArg, newIterArgIdx, mask]: llvm::zip_equal(op->getResults(), op.getRegionIterArgs(), iterArgIdxMap, maskIterArgs)) { + if (newIterArgIdx != -1) { + rewriter.replaceAllUsesWith(res, *newResultIter); + ++newResultIter; + } else { + auto key = mapping.lookup(regionArg); + auto data = known.at(key); + for (auto &offset : data.getOffsetsRef()) + offset = newOp.getTiedLoopResult(cast(offset.get())); + for (auto &stride : data.getStridesRef()) + stride = newOp.getTiedLoopResult(cast(stride.get())); + auto newRes = createFromData(cast(regionArg.getType()), data, op.getLoc(), rewriter, mask); + rewriter.replaceAllUsesWith(res, newRes); + } + } + rewriter.eraseOp(op); + + // Update the loop body. Manually invoke the rewrite logic on addptr and yield + // in the loop body, so we can take advantage of the states we built up + for (auto *region : newOp.getLoopRegions()) { + for (auto &bodyOp : region->getOps()) { + if (auto addptrOp = dyn_cast(bodyOp)) { + // FIXME: Constructed adaptor here does not hold the transformed op info. 
+ auto adaptor = triton::AddPtrOp::Adaptor(addptrOp); + rewriteAddPtr(addptrOp, adaptor, rewriter, known); + } else if (auto advanceOp = dyn_cast(bodyOp)) { + rewriteAdvanceOp(advanceOp, rewriter, known); + } else if (auto makeTensorPtrOp = dyn_cast(bodyOp)) { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(makeTensorPtrOp); + rewriteMakeTensorPtrOp(makeTensorPtrOp, rewriter.getRemappedValue(makeTensorPtrOp.getBase()), rewriter, known); + } else if (auto loopOp = dyn_cast(bodyOp); + loopOp && !loopOp->hasAttr("ExtractedLoadOrStore")) { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(loopOp); + rewriteLoopOp(loopOp, rewriter, known); + } + } + } + + if (!op.getRegionIterArgs().empty()) { + auto yieldOp = cast(newOp.getLoopRegions().back()->back().getTerminator()); + rewriteYieldOp(yieldOp, rewriter, blockArgIdxSet, iterArgIdxMap, known); + } + + LLVM_DEBUG({ + llvm::dbgs() << "new loop\n"; + newOp.getOperation()->print(llvm::dbgs(), + OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); +} + +/// @brief Rewrite the triton::AddPtrOp to handle unstructured memory access. +/// @param op The triton::AddPtrOp to be rewritten. +/// @param adaptor The adaptor of the triton::AddPtrOp, used to get operands. +/// @param rewriter The pattern rewriter used to modify the IR. +/// @param data The BlockData containing information about the memory access. +void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( + triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, BlockData &data) { + auto loc = op.getLoc(); + auto &offsets = data.getOffsetsRef(); + auto &blockSizes = data.getSizesRef(); + auto &strides = data.getStridesRef(); + Value ptrOffset = adaptor.getOffset(); + Value zeroIdx = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value oneIdx = + rewriter.create(loc, rewriter.getIndexAttr(1)); + auto addptrRes = op.getResult(); + assert(addptrRes.hasOneUse() && "Invalid: tt.addptr has multiple users"); + auto loadOp = *(addptrRes.user_begin()); + + // Prepare empty tensor for loop based scalar load + // FIXME: We use cast here because addptr must return tensor>. + // True? + auto resTy = cast(addptrRes.getType()); + auto resEPtrTy = resTy.getElementType(); + auto resETy = cast(resEPtrTy).getPointeeType(); + Value loaded = rewriter.create(loc, blockSizes, resETy); + SmallVector initArgs; + initArgs.push_back(loaded); + + SmallVector forLBs; + SmallVector forUBs; + SmallVector forSteps; + for (auto &s : offsets) { + forLBs.push_back(zeroIdx); + } + for (auto &s : blockSizes) { + forUBs.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); + } + for (auto &s : strides) { + forSteps.push_back(oneIdx); + } + SmallVector ivs; + OpBuilder builder(op); + auto loop = createNestedLoops( + builder, loc, 0, blockSizes.size(), forLBs, forUBs, forSteps, ivs, + initArgs, + [&](OpBuilder &bB, Location bLoc, SmallVector &allIVs, + ValueRange iterArgs) { + OpBuilder::InsertionGuard g(bB); + bB.setInsertionPointToStart(bB.getBlock()); + + Value scalarOffsetRaw = + bB.create(bLoc, ptrOffset, allIVs); + Value scalarOffset = bB.create( + bLoc, bB.getIndexType(), scalarOffsetRaw); + // Replace offset & size. Only single element. 
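+          // After this, the BlockData describes a single element at the
+          // extracted scalar offset (sizes = [1], strides = [1]), and
+          // createCastOp below turns it into a rank-1, one-element view.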
+ data.getOffsetsRef().clear(); + data.getOffsetsRef().push_back(scalarOffset); + data.getSizesRef().clear(); + data.getSizesRef().push_back(bB.getIndexAttr(1)); + data.getStridesRef().clear(); + data.getStridesRef().push_back(bB.getIndexAttr(1)); + memref::ReinterpretCastOp castOp = data.createCastOp({1}, bLoc, bB); + rewriter.replaceOp(op, castOp); + // Move tt.load using this tt.addptr into this block + loadOp->moveAfter(castOp); + loadOp->setAttr("IndirectLoad", UnitAttr::get(op.getContext())); + bB.create(bLoc, iterArgs); + }); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt new file mode 100644 index 000000000..5730838d5 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt @@ -0,0 +1,30 @@ +add_triton_library(TritonToLinalg + TritonToLinalgPass.cpp + LoadStoreConverter.cpp + FunctionConverter.cpp + ArgMinMaxConverter.cpp + DescriptorConverter.cpp + TritonOpConverter.cpp + BlockPtrAnalysis.cpp + MaskAnalysis.cpp + UseAnalysis.cpp + + DEPENDS + TritonToLinalgConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonTransforms + TritonAnalysis + MLIRTritonNPUUtils + MLIRSCFTransforms + MLIRLinalgTransforms +) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/DescriptorConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/DescriptorConverter.cpp new file mode 100644 index 000000000..7499ad72f --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/DescriptorConverter.cpp @@ -0,0 +1,165 @@ +#include "TritonToLinalg/DescriptorConverter.h" +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include "TritonToLinalg/TritonOpConverter.h" +#include "TritonToLinalg/TritonToLinalgPass.h" +#include "Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace DescriptorConverter { +using namespace mlir; +using namespace triton; + +bool hasATensorDescriptorType(mlir::TypeRange types) +{ + return llvm::any_of(types, [](mlir::Type t) { return llvm::isa(t); }); +} + +/** + * @brief Filter out operand segment sizes from the list of attributes since + * this attribute is operation specific and shouldn't be set arbitrarily. 
+ */
+mlir::SmallVector filterSegmentSizes(mlir::ArrayRef attrs)
+{
+    mlir::SmallVector ret;
+    llvm::copy_if(attrs, std::back_inserter(ret), [](const NamedAttribute &attr) {
+        auto attrName = attr.getName().getValue();
+        return attrName != "operandSegmentSizes";
+    });
+    return ret;
+}
+
+Descriptor unpackDescriptor(TensorDescType type, Value desc, ConversionPatternRewriter &rewriter)
+{
+    auto makeDescOp = desc.getDefiningOp();
+    assert(makeDescOp && "Descriptor must be defined by MakeTensorDescOp");
+
+    Descriptor res;
+
+    // Trace back directly to the defining tt.make_tensor_descriptor
+    res.base = makeDescOp.getBase();
+    for (auto s : makeDescOp.getShape()) {
+        res.shape.push_back(rewriter.createOrFold(makeDescOp.getLoc(), rewriter.getI64Type(), s));
+    }
+    for (auto st : makeDescOp.getStrides()) {
+        res.strides.push_back(rewriter.createOrFold(makeDescOp.getLoc(), rewriter.getI64Type(), st));
+    }
+
+    return res;
+}
+
+SmallVector computeOrder(ArrayRef shape)
+{
+    SmallVector order;
+    int rank = shape.size();
+    order.reserve(rank);
+    // Default to reverse order [dims - 1, ..., 0]
+    for (int i = rank - 1; i >= 0; --i) {
+        order.push_back(i);
+    }
+    return order;
+}
+
+LogicalResult DescriptorLoadConverter::matchAndRewrite(triton::DescriptorLoadOp op, OpAdaptor adaptor,
+                                                       ConversionPatternRewriter &rewriter) const
+{
+    auto loc = op.getLoc();
+    const auto blockShape = op.getDesc().getType().getBlockType().getShape();
+    auto descTy = op.getDesc().getType();
+    auto indices = op.getIndices();
+
+    // 1. Unpack the descriptor
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc(), rewriter);
+
+    // 2. Create a make_tensor_ptr
+    SmallVector tensorShapeValues;
+    for (auto dim : blockShape) {
+        tensorShapeValues.push_back(static_cast(dim));
+    }
+    Value tensorPtr = rewriter.create(loc,
+        desc.base,               // base address
+        desc.shape,              // shape
+        desc.strides,            // strides
+        indices,                 // offsets
+        tensorShapeValues,       // tensorShape
+        computeOrder(blockShape) // use the dynamically computed order
+    );
+    // 3. Replace the tt.load op
+    auto newLoad = rewriter.replaceOpWithNewOp(op, descTy.getSignlessBlockType(), tensorPtr);
+
+    // Preserve the other attributes of the original op
+    newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));
+
+    return success();
+}
+
+LogicalResult DescriptorStoreConverter::matchAndRewrite(triton::DescriptorStoreOp op, OpAdaptor adaptor,
+                                                        ConversionPatternRewriter &rewriter) const
+{
+    auto loc = op.getLoc();
+    const auto blockShape = op.getDesc().getType().getBlockType().getShape();
+    auto descTy = op.getDesc().getType();
+    auto indices = op.getIndices();
+
+    // 1. Unpack the descriptor
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc(), rewriter);
+
+    // 2. Create a make_tensor_ptr
+    SmallVector tensorShapeValues;
+    for (auto dim : blockShape) {
+        tensorShapeValues.push_back(static_cast(dim));
+    }
+    Value tensorPtr = rewriter.create(loc,
+        desc.base,               // base address
+        desc.shape,              // shape
+        desc.strides,            // strides
+        indices,                 // offsets
+        tensorShapeValues,       // tensorShape
+        computeOrder(blockShape) // use the dynamically computed order
+    );
+
+    // 3. Replace the tt.store op
+    Value valueToStore = adaptor.getSrc();
+
+    auto maskType = RankedTensorType::get(blockShape, rewriter.getI1Type());
+    rewriter.create(loc, DenseElementsAttr::get(maskType, true));
+
+    // Create attributes
+    auto boundaryCheck = rewriter.getDenseI32ArrayAttr({});  // empty boundary check
+    auto cacheModifier = triton::CacheModifierAttr::get(rewriter.getContext(), triton::CacheModifier::NONE);
+    auto evictionPolicy = triton::EvictionPolicyAttr::get(rewriter.getContext(), triton::EvictionPolicy::NORMAL);
+
+    // Create the store op and replace the original op
+    auto newStore = rewriter.replaceOpWithNewOp(op,  // op being replaced
+        tensorPtr,       // pointer
+        valueToStore,    // value to store
+        nullptr,         // mask
+        boundaryCheck,   // boundary check
+        cacheModifier,   // cache modifier
+        evictionPolicy   // eviction policy
+    );
+
+    // Preserve the other attributes of the original op
+    newStore->setAttrs(filterSegmentSizes(op->getAttrs()));
+    return success();
+}
+
+} // namespace DescriptorConverter
diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp
new file mode 100644
index 000000000..d1f70b04b
--- /dev/null
+++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/FunctionConverter.cpp
@@ -0,0 +1,42 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+
+#include "TritonToLinalg/FunctionConverter.h"
+
+namespace FunctionConverter {
+using namespace mlir;
+using namespace triton;
+
+LogicalResult GetProgramIDConverter::matchAndRewrite(
+    triton::GetProgramIdOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto axis = (uint32_t)op.getAxis();
+  assert(axis < GetProgramIDConverter::LAUNCH_GRID_RANK &&
+         "Invalid axis for GetProgramIdOp");
+  auto func = op->getParentOfType();
+  auto numArgs = func.getNumArguments();
+  auto id = func.getArgument(numArgs - GetProgramIDConverter::LAUNCH_GRID_RANK +
+                             axis);
+  rewriter.replaceOp(op, id);
+  return success();
+}
+
+LogicalResult GetNumProgramsConverter::matchAndRewrite(
+    triton::GetNumProgramsOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto axis = (uint32_t)op.getAxis();
+  assert(axis < GetNumProgramsConverter::LAUNCH_GRID_RANK &&
+         "Invalid axis for GetNumProgramsOp");
+  auto func = op->getParentOfType();
+  auto numArgs = func.getNumArguments();
+  auto id = func.getArgument(
+      numArgs - GetNumProgramsConverter::LAUNCH_GRID_RANK * 2 + axis);
+  rewriter.replaceOp(op, id);
+  return success();
+}
+} // namespace FunctionConverter
\ No newline at end of file
diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp
new file mode 100644
index 000000000..87443db77
--- /dev/null
+++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp
@@ -0,0 +1,979 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/LoadStoreConverter.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include "TritonToLinalg/TritonToLinalgPass.h" +#include "Utils/InterleaveOptimization.h" +#include "Utils/Utils.h" +#include "bishengir/Dialect/Annotation/IR/Annotation.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" + +#include "llvm/Support/Debug.h" + +#include +#include +#include + +#define DEBUG_TYPE "triton-load-store-converter" + +namespace LoadStoreConverter { +using namespace mlir; +using namespace triton; + +const std::string MayImplicitTransposeWithLastAxisTAG = "MayImplicitTransposeWithLastAxis"; + +LogicalResult +AddPtrConverter::matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm::SmallDenseMap known; + BlockDataParser::rewriteAddPtr(op, adaptor, rewriter, known); + return success(); +} + +LogicalResult LoadConverter::toTensorAndReplace( + triton::LoadOp &op, RankedTensorType &tensorType, memref::AllocOp &allocOp, + bool mayImplicitTransposeWithLastAxis, const Location &loc, ConversionPatternRewriter &rewriter) const { + + Value loadedTensor = rewriter.create( + loc, tensorType, allocOp, true, true); + if(mayImplicitTransposeWithLastAxis){ + auto markOp = rewriter.create(loc, loadedTensor); + markOp->setAttr(MayImplicitTransposeWithLastAxisTAG, UnitAttr::get(rewriter.getContext())); + } + rewriter.replaceOp(op, loadedTensor); + return success(); +} + +/// @brief Check whether the triton::LoadOp has been modified to the specified +/// state by the AddPtrConverter. +/// @param op The triton::LoadOp operation to be checked. +/// @return Return success if the operation conforms to the specified state; +/// otherwise, return failure. +LogicalResult +LoadConverter::checkModifiedByAddPtrConverter(triton::LoadOp &op) const { + if (!isa(op->getParentOp())) { + return failure(); + } + if (!op->hasAttr("IndirectLoad")) { + return failure(); + } + auto ptrOp = op.getPtr().getDefiningOp(); + auto ptrBlock = ptrOp->getBlock(); + auto opBlock = op->getBlock(); + if (ptrBlock == opBlock) { + return failure(); + } + + return success(); +} + +/// @brief Continue to modify the triton::LoadOp from the state modified by the +/// AddPtrConverter. +/// @param op The triton::LoadOp operation to be processed. +/// @param adaptor The adaptor for the operation, used to obtain operands. 
+/// @param rewriter The pattern rewriter used to rewrite the operation. +/// @return Return success if the operation is successful; otherwise, return +/// failure. +LogicalResult LoadConverter::continueModifyFromAddPtrConverter( + triton::LoadOp &op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto forOp = op->getParentOfType(); + Operation *firstOp = &forOp.getBody()->front(); + auto extractOp = cast(firstOp); + auto ivs = extractOp.getIndices(); + // Single iterArg which is inserted by AddPtrConverter. + auto iterArg = forOp.getRegionIterArg(0); + auto ptr = adaptor.getPtr(); + + rewriter.setInsertionPointAfter(op); + Value castVal = ptr.getDefiningOp(); + Value idxZero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value loadVal = + rewriter.create(loc, castVal, ValueRange{idxZero}); + Value insertedVal = + rewriter.create(loc, loadVal, iterArg, ValueRange{ivs}); + // a yield op is already created by AddPtrConverter. + // so we need to replace it with a new yield op. + Operation *terminator = forOp.getBody()->getTerminator(); + scf::YieldOp oldYieldOp = cast(terminator); + auto yieldOp = rewriter.create(loc, ValueRange{insertedVal}); + rewriter.replaceOp(oldYieldOp, yieldOp); + // Now the scf.for is complete, we can replace tt.load with it. + auto rank = cast(op.getResult().getType()).getShape().size(); + Operation *rootForOp = op; + while (rank != 0) { + rank--; + rootForOp = rootForOp->getParentOfType(); + } + rewriter.replaceOp(op, rootForOp); + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(rootForOp) << "\n"; }); + return success(); +} + +void LoadConverter::fillTensorWithOtherForMaskScenario( + Value other, memref::AllocOp localMem, ArrayRef maskDim, + ConversionPatternRewriter &rewriter) const { + auto loc = localMem->getLoc(); + MemRefType originalType = localMem.getType(); + assert(originalType.hasStaticShape() && "only support static shape"); + assert(originalType.getRank() == maskDim.size() && + "shape and mask must have same rank"); + + auto fillFlag = + rewriter.create(loc, rewriter.getBoolAttr(false)) + .getResult(); + + for (size_t i = 0; i < originalType.getShape().size(); ++i) { + // Use dynamic value to judge whether overstep boundary + auto shapeVal = rewriter.create( + loc, rewriter.getIndexAttr(originalType.getDimSize(i))); + + Value maskDimVal; + if (isa(maskDim[i])) + maskDimVal = rewriter.create( + loc, cast(maskDim[i].get())); + else + maskDimVal = maskDim[i].get(); + + auto curCmp = rewriter.create(loc, arith::CmpIPredicate::slt, + maskDimVal, shapeVal); + + fillFlag = rewriter.create(loc, fillFlag, curCmp.getResult()) + .getResult(); + } + + rewriter.create( + loc, fillFlag, [&](OpBuilder &builder, Location loc) { + builder.create(loc, ValueRange{other}, + ValueRange{localMem}); + builder.create(loc); + }); +} + +LoadConverter::LoadConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +LogicalResult +LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Check if tt.load is modified by AddPtrConverter to a specified state. 
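+  // That state (set up by rewriteAddPtrToUnstrucMemAcc): the load has been
+  // moved inside the generated scf.for nest and tagged with the
+  // "IndirectLoad" attribute.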
+ if (checkModifiedByAddPtrConverter(op).succeeded()) { + return continueModifyFromAddPtrConverter(op, adaptor, rewriter); + } + + auto ptr = adaptor.getPtr(); + auto mask = op.getMask(); + auto other = op.getOther(); + auto loc = op.getLoc(); + + // handling scalar + if (!isa(op.getResult().getType())) { + auto scalarMemref = + BlockDataParser::getScalarMemRef(op.getPtr(), ptr, loc, rewriter); + auto resTy = op.getResult().getType(); + auto idxZero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + auto allocOp = rewriter.create(loc, MemRefType::get({1}, resTy)); + auto srcSubView = mlir::ConverterUtils::makeSubViewOp(ptr, {rewriter.getIndexAttr(1)}, loc, rewriter); + rewriter.create(loc, srcSubView, allocOp); + auto loadedTensor = rewriter.create( + loc, RankedTensorType::get({1}, resTy), allocOp, true /* restrict */, true /* writable */); + auto loadedValue = rewriter.create(loc, resTy, loadedTensor, ValueRange({idxZero})).getResult(); + if (mask && other) { + mask = rewriter.create(loc, RankedTensorType::get({1}, mask.getType()), mask); + loadedValue = rewriter.create(loc, RankedTensorType::get({1}, loadedValue.getType()), loadedValue); + other = rewriter.create(loc, RankedTensorType::get({1}, other.getType()), other); + loadedValue = rewriter.create(loc, mask, loadedValue, other); + rewriter.replaceOpWithNewOp(op, loadedValue, ValueRange({idxZero})); + } else { + rewriter.replaceOp(op, loadedValue); + } + return success(); + } + + int64_t lastStride=-1; + if (isa(ptr)) { + auto u = ptr; + while (auto blkArg = dyn_cast(u)) { + if (auto forOp = dyn_cast(blkArg.getOwner()->getParentOp())) { + auto prt = forOp->getOperand(3+blkArg.getArgNumber()-1); + u = prt; + } else { + u=nullptr; + break; + } + } + if (u && isa(u.getDefiningOp())) { + auto ret = mlir::ConverterUtils::getLastStrideOfReinterpretCastOp(dyn_cast(u.getDefiningOp())); + if (ret.has_value()) lastStride = *ret; + } + } + + // handling no mask + auto memRefType = dyn_cast(ptr.getType()); + if (!memRefType) { + return rewriter.notifyMatchFailure( + op, "LoadOp expects a memref, not a memref of pointers"); + } + bool mayImplicitTransposeWithLastAxis = (existDotFlag) && (!op->hasAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG)) && + (lastStride != 1 && mlir::ConverterUtils::isaPermutedMemRefType(memRefType)); + auto memRefShape = memRefType.getShape(); + auto memRefElementType = memRefType.getElementType(); + + auto allocOp = rewriter.create( + loc, MemRefType::get(memRefShape, memRefElementType)); + + auto tensorType = RankedTensorType::get(memRefShape, memRefElementType); + // boundary check + auto boundaryCheck = op.getBoundaryCheck(); + if (!boundaryCheck.empty()) { + auto boundarySizes = mlir::ConverterUtils::getBoundarySizes( + boundaryCheck, /*remapped*/ ptr, loc, rewriter); + // handle the padding + auto padding = op.getPadding(); + if (padding.has_value()) { + TypedAttr padAttr = rewriter.getZeroAttr(memRefElementType); + // triton already ensure only NAN and ZERO are passed in + if (padding.value() == triton::PaddingOption::PAD_NAN) { + // FIXME: Why NaN requires elemTy to be non-int or non-index? 
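+        // Presumably because APFloat::getNaN below takes its float semantics
+        // from padAttr, and getZeroAttr only yields a FloatAttr for
+        // floating-point element types.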
+        assert(!memRefElementType.isIntOrIndex());
+        auto apNaN = llvm::APFloat::getNaN(
+            cast(padAttr).getValue().getSemantics());
+        padAttr = rewriter.getFloatAttr(memRefElementType, apNaN);
+      }
+      auto padVal = rewriter.create(loc, padAttr);
+
+      fillTensorWithOtherForMaskScenario(padVal, allocOp, boundarySizes,
+                                         rewriter);
+    }
+
+    auto srcSubView =
+        mlir::ConverterUtils::makeSubViewOp(ptr, boundarySizes, loc, rewriter);
+    auto dstSubview = mlir::ConverterUtils::makeSubViewOp(
+        allocOp, boundarySizes, loc, rewriter);
+    rewriter.create(loc, srcSubView, dstSubview);
+    if (mayImplicitTransposeWithLastAxis) {
+      auto markOp = rewriter.create(loc, dstSubview);
+      markOp->setAttr(MayImplicitTransposeWithLastAxisTAG, UnitAttr::get(rewriter.getContext()));
+    }
+    return this->toTensorAndReplace(op, tensorType, allocOp, mayImplicitTransposeWithLastAxis, loc, rewriter);
+  }
+
+  if (!mask) {
+    assert(!other && "cannot input 'other' when 'mask' is not set");
+    if (auto unrealizedCastOp =
+            ptr.getDefiningOp()) {
+      // TODO : handling values associated with "module" is not supported yet
+      // hint : can be handled in Linearize
+      op->emitError("meeting unexpected UCC in LoadConverter!");
+      return failure();
+    } else {
+      // If the last dimension stride equals 2, try the deinterleave optimization.
+      auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType);
+      if (ptrStrides.back() == 2 && (memRefShape.back() % 2 == 0) &&
+          mlir::triton::DeinterleaveStatusOptimization(op, adaptor, rewriter)
+              .succeeded()) {
+        return success();
+      }
+      rewriter.create(loc, ptr, allocOp);
+      if (mayImplicitTransposeWithLastAxis) {
+        auto markOp = rewriter.create(loc, allocOp);
+        markOp->setAttr(MayImplicitTransposeWithLastAxisTAG, UnitAttr::get(rewriter.getContext()));
+      }
+    }
+
+    return this->toTensorAndReplace(op, tensorType, allocOp, mayImplicitTransposeWithLastAxis, loc, rewriter);
+  }
+
+  MaskState mstate;
+  auto isContMask = mstate.parse(mask, loc, rewriter);
+  if (isContMask.failed()) {
+    return rewriter.notifyMatchFailure(
+        op, "cannot lower discontinuous masked loads");
+  }
+
+  if (other) {
+    auto scalarOther =
+        mlir::ConverterUtils::getScalarValue(other, loc, rewriter);
+    assert(
+        scalarOther &&
+        "other value used in masked load produced by unsupported instruction!");
+
+    fillTensorWithOtherForMaskScenario(scalarOther, allocOp, mstate.dims,
+                                       rewriter);
+  }
+
+  // To enable the deinterleave optimization with a masked load, the mask state
+  // along the last dimension must not be split, which means `dims.back()` must
+  // equal the constant size of the original type's last dimension and
+  // `offsets.back()` must be 0.
+  //
+  // The rationale is that a range comparison on the last dimension would
+  // generate an unacceptable discontinuous mask.
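+  // For example (hypothetical shapes): for a memref<16x64xf16> whose last
+  // stride is 2, a mask that only bounds the first dimension leaves
+  // offsets.back() == 0 and dims.back() == 64, so the deinterleave path below
+  // still applies.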
+  if (mstate.getRank() == memRefType.getRank() &&
+      isConstantIntValue(mstate.offsets.back(), 0) &&
+      isConstantIntValue(mstate.dims.back(), memRefType.getShape().back())) {
+    auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType);
+    if (ptrStrides.back() == 2 && (memRefType.getShape().back() % 2 == 0) &&
+        DeinterleaveStatusWithMaskOptimization(op, adaptor, rewriter, mstate,
+                                               allocOp)
+            .succeeded()) {
+      return success();
+    }
+  }
+
+  if (auto unrealizedCastOp = ptr.getDefiningOp()) {
+    // TODO : handling values associated with "module" is not supported yet
+    // hint : can be handled in Linearize
+    op->emitError("meeting unexpected UCC in LoadConverter!");
+    return failure();
+  } else {
+    memref::SubViewOp srcSubView = mstate.getSubview(ptr, loc, rewriter);
+    memref::SubViewOp dstSubView = mstate.getSubview(allocOp, loc, rewriter);
+    rewriter.create(loc, srcSubView, dstSubView);
+    if (mayImplicitTransposeWithLastAxis) {
+      auto markOp = rewriter.create(loc, dstSubView);
+      markOp->setAttr(MayImplicitTransposeWithLastAxisTAG, UnitAttr::get(rewriter.getContext()));
+    }
+  }
+  return this->toTensorAndReplace(op, tensorType, allocOp, mayImplicitTransposeWithLastAxis, loc, rewriter);
+}
+
+AtomicRMWConverter::AtomicRMWConverter(MLIRContext *context)
+    : OpConversionPattern(context) {}
+
+// Lower tt.atomic_rmw to linalg.generic.
+// If the atomic op's return value is used by another op (it is the old value
+// stored at the ptr), we use tt.load to fetch it.
+//
+// example:
+// input:
+// %return_value = tt.atomic_rmw fadd, acq_rel, gpu,
+//                    %output_memref, %input_tensor, %mask :
+//                    (tensor<256x!tt.ptr>, tensor<256xf32>, tensor<256xi1>)
+//                    -> tensor<256xf32>
+//
+// output:
+// memref.copy %output_memref, %ub_buf : memref to memref
+// %17 = bufferization.to_tensor %alloc_3 restrict writable : memref<256xf32>
+// linalg.generic
+//    {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+//    ins(%output_memref, %masked_input_memref : memref, memref)
+//    outs(%subview_2 : memref)
+//    attrs = {GenericAtomicRMW = "fadd", MemSemantic = "acq_rel",
+//    MemSyncScope = "gpu"} {
+//    ^bb0(%in: f32, %in_9: f32, %out: f32):
+//      %25 = arith.addf %in, %in_9 : f32
+//      linalg.yield %25 : f32
+//    }
+LogicalResult
+AtomicRMWConverter::matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // If the result of AtomicRMWOp is not used, we don't need to load the old
+  // data stored at the ptr
+  auto ptr = adaptor.getPtr();
+  auto val = op.getVal();
+  auto loc = op.getLoc();
+
+  auto resType = dyn_cast(op.getResult().getType());
+  if (!resType) {
+    return rewriter.notifyMatchFailure(
+        op, "atomicRMWConverter: scalar will be handled by "
+            "ScalarAtomicRMWCanonicalizer");
+  }
+
+  auto rmwOp = op.getAtomicRmwOp();
+  if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) {
+    return rewriter.notifyMatchFailure(
+        op, "AtomicRMWConverter: unsupported atomic kind for now");
+  }
+
+  // 1. Simple case where no mask is used.
+  auto type = dyn_cast(ptr.getType());
+  if (!type) {
+    // Seen when implicit broadcasting is done late in a chain of
+    // operations. The workaround is to broadcast the pointers early in the
+    // address calculation. A proper fix is complicated, but at least we can
+    // provide a better error message.
+ return rewriter.notifyMatchFailure( + op, "AtomicRMWOp expects a memref, not a memref of pointers"); + } + + auto dstMemref = ptr; + // Well, linalg structure op wouldn't support mixed tensor/buffer semantics + // any more in latest LLVM(triton LLVM dependency has involed this), so we + // need to convert tensor to buffer early. + auto dstOriType = cast(dstMemref.getType()); + MemRefType dstType = MemRefType::get(dstOriType.getShape(), dstOriType.getElementType()); + Value inputMemref = + rewriter.create(loc, dstType, val); + + // 2. handle the mask for the atomic op + // When the dsl do not pass the mask to this op like + // `tl.atomic_add(out_ptr0 + xindex, tmp2)`, it will create a constant mask + // for this op by default, which is not supported by maskAnalysis, so we + // need to handle this situation + // + // This logic come from semantic.py: + // + // if not mask: + // mask_ir = builder.get_int1(True) + // mask_ty = tl.int1 + // if ptr.type.is_block(): + // mask_ir = \ + // builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + // mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + // mask = tl.tensor(mask_ir, mask_ty) + // + // ... + // + // return ptr, val, mask + // + if (auto mask = op.getMask()) { + MaskState mstate; + auto constantMask = mask.getDefiningOp(); + if (!constantMask) { + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + return rewriter.notifyMatchFailure( + op, "Cannot lower continuous masked loads"); + } + dstMemref = mstate.getSubview(ptr, loc, rewriter); + inputMemref = mstate.getSubview(inputMemref, loc, rewriter); + } else { + if (!isConstantMaskTrue(mask)) { + rewriter.eraseOp(op); + return success(); + } + } + } + + // create element-wise map + int64_t rank = type.getRank(); + SmallVector inputDims; + auto context = rewriter.getContext(); + + for (int i = 0; i < rank; i++) { + inputDims.push_back(getAffineDimExpr(i, context)); + } + + SmallVector indexingMaps; + // As mask has been erased for now + // the number of input must be 2 + // the input memref is also the output memref + // Thus, there are a total of three inputs and outputs. 
+ // so here we have 3 map to create + for (int i = 0; i < 3; i++) { + indexingMaps.push_back(AffineMap::get(rank, 0, inputDims, context)); + } + + auto linalgOp = rewriter.create( + loc, /* operands */ ValueRange{dstMemref, inputMemref}, + ValueRange{dstMemref}, indexingMaps, + mlir::ConverterUtils::getNParallelLoopsAttrs(rank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { + Value opResult = createAtomicBinaryOps(nestedBuilder, nestedLoc, op, + type.getElementType(), + blockArgs[0], blockArgs[1]); + nestedBuilder.create(nestedLoc, opResult); + }); + + // "library_call" + // indicating the actual semantic of this op + // TODO: If the hardware support the MemSemantic/MemSyncScope + // We pass them down + // otherwise they need to be deleted + const StringRef genericAtomicRMW = "GenericAtomicRMW"; + const StringRef memSemantic = "MemSemantic"; + const StringRef memSyncScope = "MemSyncScope"; + linalgOp->setAttr(genericAtomicRMW, + rewriter.getStringAttr(stringifyEnum(op.getAtomicRmwOp()))); + linalgOp->setAttr(memSemantic, + rewriter.getStringAttr(stringifyEnum(op.getSem()))); + linalgOp->setAttr(memSyncScope, + rewriter.getStringAttr(stringifyEnum(op.getScope()))); + + // Mark atomic_and/or/xor specially which need software simulation in terms + // of backend restriction + if (softwareAtomicKinds.contains(op.getAtomicRmwOp())) + linalgOp->setAttr("Software", rewriter.getUnitAttr()); + + + // tt.atomicRMW op has two part of feature + // 1. load the old data at the ptr + // 2. atomically store the data on ub to the ptr + // at the same time it perform the action it has been assigned + // So we lower this op to load + atomically store + // + // The first part is not necessary when the returned value of atomic op + // is not used, it will be deleted cause it's meaningless + // Here, we preemptively determine whether it will be used + // and decide whether it is necessary to create the load process based on + // this assessment. + // + // logic of handling is copied + // TODO: decoupling the logic of load, put it in the Utils + if (!op.getResult().use_empty()) { + auto tensorType = + RankedTensorType::get(type.getShape(), type.getElementType()); + auto alloc = rewriter.create( + loc, MemRefType::get(type.getShape(), type.getElementType())); + + // For the return value, don't need to care about mask for now + // this op don't support other, so we best not fill it + rewriter.create(loc, ptr, alloc); + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + } else { + rewriter.eraseOp(op); + } + return success(); +} + +LogicalResult +AtomicCASConverter::matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // If the result of AtomicCASOp is not used, we don't need to load the old + // data stored at the ptr + auto ptr = adaptor.getPtr(); + auto cmp = op.getCmp(); + auto val = op.getVal(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + if (!resType) { + return rewriter.notifyMatchFailure( + op, "atomicCASConverter: scalar will be handled by " + "ScalarAtomicCASCanonicalizer"); + } + + // 1. Simple case where no mask is used. + auto type = dyn_cast(ptr.getType()); + if (!type) { + // Seen when implicit broadcasting is done late in a chain of + // operations. The workaround is to broadcast the pointers early in the + // address calculation. 
A proper fix is complicated, but at least we can + // provide a better error message. + return rewriter.notifyMatchFailure( + op, "AtomicCASOp expects a memref, not a memref of pointers"); + } + + auto dstMemref = ptr; + // Well, linalg structure op wouldn't support mixed tensor/buffer semantics + // any more in latest LLVM(triton LLVM dependency has involed this), so we + // need to convert tensor to buffer early. + auto dstOriType = cast(dstMemref.getType()); + MemRefType dstType = MemRefType::get(dstOriType.getShape(), dstOriType.getElementType()); + Value inputMemref = + rewriter.create(loc, dstType, val); + + Value cmpMemref = + rewriter.create(loc, dstType, cmp); + + // create element-wise map + int64_t rank = type.getRank(); + SmallVector inputDims; + auto context = rewriter.getContext(); + + for (int i = 0; i < rank; i++) { + inputDims.push_back(getAffineDimExpr(i, context)); + } + + SmallVector indexingMaps; + // As mask has been erased for now + // the number of input must be 2 + // the input memref is also the output memref + // Thus, there are a total of four inputs and outputs. + // so here we have 4 map to create + for (int i = 0; i < 4; i++) { // 4: 3 input and 1 output + indexingMaps.push_back(AffineMap::get(rank, 0, inputDims, context)); + } + + auto linalgOp = rewriter.create( + loc, ValueRange{dstMemref, cmpMemref, inputMemref}, + mlir::ValueRange{dstMemref}, indexingMaps, + mlir::ConverterUtils::getNParallelLoopsAttrs(rank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { + Value lhs = blockArgs[0]; + Value rhs = blockArgs[1]; + Value setValue = blockArgs[2]; + Value cond; + if (mlir::isa(lhs.getType())) { + cond = nestedBuilder.create( + nestedLoc, arith::CmpFPredicate::UEQ, lhs, rhs); + } else { + cond = nestedBuilder.create( + nestedLoc, arith::CmpIPredicate::eq, lhs, rhs); + } + auto ifOp = nestedBuilder.create( + nestedLoc, TypeRange{setValue.getType()}, cond, true); + { + OpBuilder::InsertionGuard guard(nestedBuilder); + nestedBuilder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); + nestedBuilder.create(nestedLoc, setValue); + } + { + OpBuilder::InsertionGuard guard(nestedBuilder); + nestedBuilder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); + nestedBuilder.create(nestedLoc, lhs); + } + nestedBuilder.setInsertionPointToEnd(nestedBuilder.getBlock()); + nestedBuilder.create(nestedLoc, + ifOp.getResult(0)); + }); + + const StringRef genericAtomicRMW = "GenericAtomicRMW"; + const StringRef memSemantic = "MemSemantic"; + const StringRef memSyncScope = "MemSyncScope"; + auto attr = mlir::StringAttr::get(context, "cas"); + + linalgOp->setAttr(genericAtomicRMW, attr); + linalgOp->setAttr(memSemantic, + rewriter.getStringAttr(stringifyEnum(op.getSem()))); + linalgOp->setAttr(memSyncScope, + rewriter.getStringAttr(stringifyEnum(op.getScope()))); + + linalgOp->setAttr("Software", rewriter.getUnitAttr()); + + // tt.atomicRMW op has two part of feature + // 1. load the old data at the ptr + // 2. atomically store the data on ub to the ptr + // at the same time it perform the action it has been assigned + // So we lower this op to load + atomically store + // + // The first part is not necessary when the returned value of atomic op + // is not used, it will be deleted cause it's meaningless + // Here, we preemptively determine whether it will be used + // and decide whether it is necessary to create the load process based on + // this assessment. 
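+  //
+  // Illustrative sketch of the branch below when the CAS result is consumed
+  // (names are for exposition only):
+  //   %alloc = memref.alloc() : memref<...>
+  //   memref.copy %ptr, %alloc
+  //   %old = bufferization.to_tensor %alloc restrict writable
+  // %old then replaces the tt.atomic_cas result; otherwise the op is erased.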
+ // + // logic of handling is copied + if (!op.getResult().use_empty()) { + auto tensorType = + RankedTensorType::get(type.getShape(), type.getElementType()); + auto alloc = rewriter.create( + loc, MemRefType::get(type.getShape(), type.getElementType())); + + // For the return value, don't need to care about mask for now + // this op don't support other, so we best not fill it + rewriter.create(loc, ptr, alloc); + Value tensor = rewriter.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + rewriter.replaceOp(op, tensor); + } else { + rewriter.eraseOp(op); + } + return success(); +} + +LogicalResult +ScalarStoreCanonicalizer::matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const { + if (!op.getValue().getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarStoreCanonicalizer handles scalar store scene!"); + } + auto ptr = op.getPtr(); + auto mask = op.getMask(); + auto value = op.getValue(); + if (mask) { + rewriter.replaceOpWithNewOp(op, mask, + [&](OpBuilder &b, Location loc) { + b.create( + loc, ptr, value, op.getCache(), op.getEvict()); + b.create(loc); + }); + return success(); + } + + auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType()); + auto ptrSplat = rewriter.create(op.getLoc(), ptrTy, ptr); + auto valTy = RankedTensorType::get({(int64_t)1}, value.getType()); + auto valSplat = + rewriter.create(op.getLoc(), valTy, value); + auto newStoreOp = rewriter.create( + op.getLoc(), ptrSplat, valSplat, op.getCache(), op.getEvict()); + rewriter.replaceOp(op, newStoreOp); + return success(); +} + +LogicalResult +ScalarAtomicRMWCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const { + if (!op.getVal().getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarAtomicRMWCanonicalizer handles scalar atomic rmw op scene!"); + } + + auto ptr = op.getPtr(); + auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType()); + auto ptrSplat = rewriter.create(op.getLoc(), ptrTy, ptr); + auto valTy = RankedTensorType::get({(int64_t)1}, op.getVal().getType()); + auto valSplat = + rewriter.create(op.getLoc(), valTy, op.getVal()); + auto maskTy = RankedTensorType::get({(int64_t)1}, op.getMask().getType()); + auto maskSplat = + rewriter.create(op.getLoc(), maskTy, op.getMask()); + + auto newAtomicOp = rewriter.create( + op.getLoc(), valTy, op.getAtomicRmwOp(), ptrSplat, valSplat, maskSplat, + op.getSem(), op.getScope()); + auto idxZero = + rewriter.create(op.getLoc(), rewriter.getIndexAttr(0)); + rewriter.replaceOpWithNewOp(op, newAtomicOp, ValueRange({idxZero})); + return success(); +} + +LogicalResult +ScalarAtomicCASCanonicalizer::matchAndRewrite(triton::AtomicCASOp op, + PatternRewriter &rewriter) const { + if (!op.getVal().getType().isIntOrIndexOrFloat() && + !op.getCmp().getType().isIntOrIndexOrFloat()) { + return rewriter.notifyMatchFailure( + op, "ScalarAtomicCASCanonicalizer handles scalar atomic cas op scene!"); + } + + auto ptr = op.getPtr(); + auto ptrTy = RankedTensorType::get({(int64_t)1}, ptr.getType()); + auto ptrSplat = rewriter.create(op.getLoc(), ptrTy, ptr); + auto cmpTy = RankedTensorType::get({(int64_t)1}, op.getCmp().getType()); + auto cmpSplat = + rewriter.create(op.getLoc(), cmpTy, op.getCmp()); + auto valTy = RankedTensorType::get({(int64_t)1}, op.getVal().getType()); + auto valSplat = + rewriter.create(op.getLoc(), valTy, op.getVal()); + + auto newAtomicOp = rewriter.create( + op.getLoc(), valTy, ptrSplat, cmpSplat, valSplat, 
op.getSem(), + op.getScope()); + auto idxZero = + rewriter.create(op.getLoc(), rewriter.getIndexAttr(0)); + rewriter.replaceOpWithNewOp(op, newAtomicOp, ValueRange({idxZero})); + return success(); +} + +// The atomic max op with float input will be devided into +// two atomic max ops with integer input +// One handles the part of the tensor greater than zero +// the other deals with the part less than zero +// It will lead to maskAnalysis failure +// So here we need to revert the procedures in semantics.py +// The triton IR is like +// +// %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x256xf32> +// %1 = tt.bitcast %value : tensor<1x256xf32> -> tensor<1x256xi32> +// %2 = tt.bitcast %ptr : tensor<1x256x!tt.ptr> -> +// tensor<1x256x!tt.ptr> %3 = arith.cmpf oge, %1, %cst_0 %4 = arith.cmpf +// olt, %1, %cst_0 %5 = arith.andi %8, %3 %6 = tt.atomic_rmw max, acq_rel, gpu, +// %2, %1, %5 : +// (tensor<1x256x!tt.ptr>, tensor<1x256xi32>, tensor<1x256xi1>) -> +// tensor<1x256xi32> +// %7 = arith.andi %8, %4 +// %8 = tt.atomic_rmw umin, acq_rel, gpu, %2, %1, %7 : +// (tensor<1x256x!tt.ptr>, tensor<1x256xi32>, tensor<1x256xi1>) -> +// tensor<1x256xi32> +// +// it's hard to handle and meaningless complicated for our device +// so we revert it to +// %0 = tt.atomic_rmw max, acq_rel, gpu, %23, %21, %8 : +// (tensor<1x256x!tt.ptr>, tensor<1x256xf32>, tensor<1x256xi1>) -> +// tensor<1x256xf32> +LogicalResult +AtomicMaxMinCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op, + PatternRewriter &rewriter) const { + // Revert the op to its original form + auto ptrBitcastOp = op.getPtr().getDefiningOp(); + auto valueBitcastOp = op.getVal().getDefiningOp(); + if (!ptrBitcastOp || !valueBitcastOp) { + return failure(); + } + + // We only need to handle the op when the element type is float + auto elementType = + dyn_cast(valueBitcastOp.getSrc().getType()).getElementType(); + if (!isa(elementType)) { + return failure(); + } + + auto rmwOp = op.getAtomicRmwOp(); + // here we know that atomic UMAX/UMIN + // is created by special logic of triton right now + // so we can simply delete it + if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) { + // if the return value of op is used, we can't simply erase it + if (op.getResult().use_empty()) { + rewriter.eraseOp(op); + return success(); + } + return failure(); + } + + if (rmwOp != triton::RMWOp::MAX && rmwOp != triton::RMWOp::MIN) { + return failure(); + } + + // 1. Though semantic interpreter will generate full true tensor as original + // mask if atomicrmwOp don't have it, above float devision process will also + // generate positive and negative comparison mask, which will cause to fold + // true mask. + // 2. 
While if atomicrmwOp has original mask, there exists andiop between + // original mask and positive/negative comparison mask + // + // Here wanna extract original mask + Value originalMask = op.getMask(); + if (auto andOp = originalMask.getDefiningOp()) + // LHS is convention in semantic interpreter + originalMask = andOp.getLhs(); + else if (auto cmpOp = originalMask.getDefiningOp()) { + if (cmpOp.getPredicate() != mlir::arith::CmpFPredicate::OGE || + !matchPattern(cmpOp.getRhs(), + /*positive float zero matcher*/ m_PosZeroFloat())) + // Here recheck frontend interpreter generation in no manual mask state + return op->emitError("Illegal mask for atomicrmwOp of float type"); + // Restore original true mask + originalMask = rewriter.create( + op->getLoc(), + /*typed attr*/ DenseElementsAttr::get( + cast(originalMask.getType()), true)); + } else + return op->emitError("Illegal mask for atomicrmwOp of float type"); + + auto originAtomicOp = rewriter.create( + op.getLoc(), valueBitcastOp.getSrc().getType(), op.getAtomicRmwOp(), + ptrBitcastOp.getSrc(), valueBitcastOp.getSrc(), originalMask, op.getSem(), + op.getScope()); + + // if the return value of op is used + // we need to handle its usage + // In semantic.py, if the atomic Max/Min with float input is used + // It will use select + bitcast to get float value + // so here we need to revert it too + // + // For example: + // %0 = tt.atomic_rmw max, acq_rel, gpu, %gm, %input, %mask1 : + // (tensor<32x!tt.ptr>... %1 = tt.atomic_rmw umin, acq_rel, gpu, %gm, + // %input, %mask2 : (tensor<32x!tt.ptr>... %2 = arith.select + // %devidedMask, %0, %1 : tensor<32xi1>, tensor<32xi32> %3 = tt.bitcast %2 : + // tensor<32xi32> -> tensor<32xf32> tt.store %outputMemref, %3 : + // tensor<32x!tt.ptr> + // + // will be revert to: + // %0 = tt.atomic_rmw max, acq_rel, gpu, %gm, %input, %mask : + // (tensor<32x!tt.ptr>... tt.store %outputMemref, %0 : + // tensor<32x!tt.ptr> + // + if (!op.getResult().use_empty()) { + for (OpOperand &use : op->getUses()) { + auto selectOp = dyn_cast(use.getOwner()); + if (!selectOp) + continue; + + for (OpOperand &selectUse : selectOp->getUses()) { + if (auto bitcastOp = + dyn_cast(selectUse.getOwner())) { + bitcastOp.getResult().replaceAllUsesWith(originAtomicOp); + } + } + } + rewriter.replaceOp(op, originAtomicOp); + } else { + rewriter.eraseOp(op); + } + + return success(); +} + +StoreConverter::StoreConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +LogicalResult +StoreConverter::matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // triton store op basic + auto mask = op.getMask(); + auto loc = op.getLoc(); + auto ptr = adaptor.getPtr(); + auto val = adaptor.getValue(); + + // 1. boundary size check + auto boundaryCheck = op.getBoundaryCheck(); + if (!boundaryCheck.empty()) { + auto boundarySizes = mlir::ConverterUtils::getBoundarySizes( + boundaryCheck, /*remapped*/ ptr, loc, rewriter); + auto srcSlice = mlir::ConverterUtils::makeExtractSliceOp(val, boundarySizes, loc, rewriter); + auto dstSubview = mlir::ConverterUtils::makeSubViewOp(ptr, boundarySizes, loc, rewriter); + auto storeOp = + rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); + } + + // 2. Simple load with no mask + if (!mask) { + auto storeOp = rewriter.create( + loc, val, ptr); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); + } + + // 3. Continuous masked stores. 
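+  //    (e.g. a mask computed as `offs < n` with offs = tl.arange(0, BLOCK);
+  //    MaskState recovers the contiguous range [0, n) from such patterns)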
+ // Analyze the mask operand to determine at runtime the size of the data we + // are moving. + MaskState mstate; + auto isContMask = mstate.parse(mask, loc, rewriter); + + if (isContMask.failed()) { + return failure(); + } + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); + auto srcSlice = mstate.getExtractSlice(val, loc, rewriter); + auto dstSubview = mstate.getSubview(ptr, loc, rewriter); + auto storeOp = rewriter.create( + loc, srcSlice, dstSubview); + storeOp.setWritable(true); + rewriter.eraseOp(op); + return success(); +} +} // namespace LoadStoreConverter \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp new file mode 100644 index 000000000..17f231b09 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/MaskAnalysis.cpp @@ -0,0 +1,517 @@ +#include "TritonToLinalg/MaskAnalysis.h" +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include +#include + +#define DEBUG_TYPE "mask-analysis" + +namespace mlir { + +namespace triton { + +LogicalResult MaskState::parse(Value operand, const Location &loc, + OpBuilder &builder) { + if (isa(operand.getType())) { + return parseIntScalar(operand, loc, builder); + } + + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + LLVM_DEBUG({ + llvm::dbgs() << "[MaskState]==> parse op\n" + << *definingOp << "\n[MaskState]<==\n"; + }); + return TypeSwitch(definingOp) + .Case( + [&](auto op) { return this->parseConstant(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseAdd(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseAnd(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseCmp(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseMakeRange(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseBroadcast(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseSplat(op, loc, builder); }) + .Case( + [&](auto op) { return this->parseExpandDims(op, loc, builder); }) + .Case( + [&](auto op) { return this->parse(op.getIn(), loc, builder); }) + .Case( + [&](auto op) { return this->parseDiv(op, loc, builder); }) + .Default([&](Operation *op) { return failure(); }); +} + +// extractSlice +tensor::ExtractSliceOp MaskState::getExtractSlice(Value source, + const Location &loc, + OpBuilder &builder) const { + auto sourceRType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + + auto dstRType = tensor::ExtractSliceOp::inferResultType(sourceRType, offsets, + dims, strides); + return builder.create(loc, dstRType, source, offsets, + dims, strides); +} + +tensor::InsertSliceOp MaskState::getInsertSlice(Value source, Value dest, + const Location &loc, + OpBuilder &builder) const { + SmallVector strides(getRank(), builder.getIndexAttr(1)); + return builder.create(loc, source, dest, offsets, dims, + strides); +} + +memref::SubViewOp MaskState::getSubview(Value source, const Location &loc, + OpBuilder &builder) const { + auto sourceType = cast(source.getType()); + SmallVector strides(getRank(), builder.getIndexAttr(1)); + auto 
dstType = + memref::SubViewOp::inferResultType(sourceType, offsets, dims, strides); + return builder.create(loc, cast(dstType), + source, offsets, dims, strides); +} + +static memref::SubViewOp createSubview(Value src, const Location &loc, + OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return builder.create(loc, cast(dstType), src, + offsets, sizes, strides); +} + +LogicalResult MaskState::addStateScalar(const MaskState &state, + const OpFoldResult scalar, + const Location &loc, + OpBuilder &builder) { + start = addOpFoldResult(state.start, scalar, loc, builder); + end = addOpFoldResult(state.end, scalar, loc, builder); + dims = state.dims; + offsets = state.offsets; + return success(); +} + +LogicalResult MaskState::addStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (lhsState.scalar && rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; + return failure(); + } + if (!lhsState.scalar && !rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported scenario where neither lhs nor rhs is a scalar"; + return failure(); + } + + if (lhsState.scalar) { + return addStateScalar(rhsState, lhsState.scalar, loc, builder); + } else { + return addStateScalar(lhsState, rhsState.scalar, loc, builder); + } +} + +LogicalResult MaskState::divStateScalar(const MaskState &state, + const OpFoldResult scalar, + const Location &loc, + OpBuilder &builder) { + start = divOpFoldResult(state.start, scalar, loc, builder); + end = divOpFoldResult(state.end, scalar, loc, builder); + dims = state.dims; + offsets = state.offsets; + return success(); +} + +LogicalResult MaskState::divStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (!lhsState.scalar && rhsState.scalar) { + if (isZeroIndex(rhsState.scalar)) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported scenario where rhs is zero constant in divide!"; + return failure(); + } + + return divStateScalar(lhsState, rhsState.scalar, loc, builder); + } + + InFlightDiagnostic diag = emitError(loc) + << "Supported scenario where only rhs is a scalar"; + return failure(); +} + +LogicalResult MaskState::minStates(const MaskState &lhsState, + const MaskState &rhsState, + const Location &loc, OpBuilder &builder) { + if (lhsState.getRank() != rhsState.getRank()) { + InFlightDiagnostic diag = + emitError(loc) + << "Unexpected case where lhs and rhs have different ranks"; + return failure(); + } + + for (uint32_t i = 0; i < lhsState.getRank(); i++) { + auto lhsOffset = lhsState.offsets[i]; + auto rhsOffset = rhsState.offsets[i]; + auto newOffset = maxOpFoldResult(lhsOffset, rhsOffset, loc, builder); + auto lhsDim = lhsState.dims[i]; + auto rhsDim = rhsState.dims[i]; + auto lhsEnd = addOpFoldResult(lhsOffset, lhsDim, loc, builder); + auto rhsEnd = addOpFoldResult(rhsOffset, rhsDim, loc, builder); + auto newEnd = minOpFoldResult(lhsEnd, rhsEnd, loc, builder); + auto newDim = subOpFoldResult(newEnd, newOffset, loc, builder); + + offsets.push_back(newOffset); + dims.push_back(newDim); + } + return success(); +} + +// Helper func for MaskState::parse() +LogicalResult MaskState::parseConstant(arith::ConstantOp constOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if 
(isa(constOp.getValue())) { + auto attr = cast(constOp.getValue()); + auto elementType = attr.getElementType(); + assert(attr.isSplat() && isa(elementType) && + "All elements must share a single integer constant value"); + this->scalar = builder.getIndexAttr( + attr.getSplatValue().getValue().getSExtValue()); + } else { + auto value = cast(constOp.getValue()).getInt(); + this->scalar = builder.getIndexAttr(value); + } + return success(); +} + +// parseIntScalar +LogicalResult MaskState::parseIntScalar(Value scalar, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + this->scalar = getOpFoldResultOfLayoutInfo(scalar, builder); + return success(); +} + +LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(addOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(addOp.getRhs(), loc, builder))) { + return failure(); + } + return this->addStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseDiv(arith::DivSIOp divOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(divOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(divOp.getRhs(), loc, builder))) { + return failure(); + } + return this->divStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + MaskState lhsState; + if (failed(lhsState.parse(andOp.getLhs(), loc, builder)) || + !lhsState.isMask()) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(andOp.getRhs(), loc, builder)) || + !rhsState.isMask()) { + return failure(); + } + + if (!lhsState.isMask() && !rhsState.isMask()) { + return failure(); + } + + // Only support both lhs and rhs satisfy `isMask` condition + return this->minStates(lhsState, rhsState, loc, builder); +} + +LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + // Only support <, <=, >=, = + if (cmpOp.getPredicate() != arith::CmpIPredicate::slt && + cmpOp.getPredicate() != arith::CmpIPredicate::sle && + cmpOp.getPredicate() != arith::CmpIPredicate::sge && + cmpOp.getPredicate() != arith::CmpIPredicate::eq) { + LLVM_DEBUG({ llvm::dbgs() << "Unsupported cmpi predicate\n"; }); + return failure(); + } + MaskState lhsState; + if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder))) { + return failure(); + } + + MaskState rhsState; + if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder))) { + return failure(); + } + + if (!(!lhsState.scalar && rhsState.scalar)) { + InFlightDiagnostic diag = emitError(loc) + << "[MaskState] Unsupported cmpi scenario"; + return failure(); + } + + int32_t cmpDim = -1; + for (int32_t i = 0; i < lhsState.getRank(); i++) { + auto constDimLength = getConstantIntValue(lhsState.dims[i]); + if (!constDimLength || constDimLength.value() != 1) { + if (cmpDim != -1) { + InFlightDiagnostic diag = emitError(loc) + << "Unsupported cmpi with more than one " + "dimension with size larger than 1"; + return failure(); + } + cmpDim = i; + } + } + + assert(cmpDim != -1 && + "Unexpected case where no dimension has size larger than 1"); + + this->offsets = lhsState.offsets; + this->dims = lhsState.dims; + switch (cmpOp.getPredicate()) { + case 
arith::CmpIPredicate::slt: { + auto realBound = + maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); + auto newEnd = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newDim = subOpFoldResult(newEnd, lhsState.start, loc, builder); + + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::sle: { + // lhs <= rhs <=> lhs < rhs + 1 + auto rhsPlusOne = addOpFoldResult(rhsState.scalar, builder.getIndexAttr(1), loc, builder); + auto realBound = maxOpFoldResult(lhsState.start, rhsPlusOne, loc, builder); + auto newEnd = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newDim = subOpFoldResult(newEnd, lhsState.start, loc, builder); + + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::sge: { + auto realBound = + maxOpFoldResult(lhsState.start, rhsState.scalar, loc, builder); + auto newStart = minOpFoldResult(lhsState.end, realBound, loc, builder); + auto newOffset = subOpFoldResult(newStart, lhsState.start, loc, builder); + auto newDim = subOpFoldResult(lhsState.end, newStart, loc, builder); + + this->offsets[cmpDim] = newOffset; + this->dims[cmpDim] = newDim; + break; + } + case arith::CmpIPredicate::eq: { + auto newOffset = + subOpFoldResult(rhsState.scalar, lhsState.start, loc, builder); + auto newDim = builder.getIndexAttr(1); + + this->offsets[cmpDim] = newOffset; + this->dims[cmpDim] = newDim; + break; + } + default: + return failure(); + } + return success(); +} + +LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto shape = cast(rangeOp.getType()).getShape(); + auto start = rangeOp.getStart(); + auto end = rangeOp.getEnd(); + auto stride = (end - start + shape[0] - 1) / shape[0]; + + if (stride != 1) { + InFlightDiagnostic diag = + emitError(loc) + << "stride must be 1 for make_range whose result is used " + "as load or store masks"; + return failure(); + } + + this->start = builder.getIndexAttr(start); + this->end = builder.getIndexAttr(end); + this->dims.push_back(builder.getIndexAttr(shape[0])); + this->offsets.push_back(builder.getIndexAttr(0)); + return success(); +} + +LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + auto src = broadcastOp.getSrc(); + auto dst = broadcastOp.getResult(); + assert(isa(src.getType()) && + "input to tt.broadcast should be a tensor"); + + auto srcShape = cast(src.getType()).getShape(); + auto dstShape = cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && + "rank of source and destination should match"); + + if (failed(parse(src, loc, builder))) { + return failure(); + } + for (size_t i = 0; i < srcShape.size(); i++) { + if (srcShape[i] == dstShape[i]) + continue; + else if (srcShape[i] < dstShape[i]) + this->dims[i] = builder.getIndexAttr(dstShape[i]); + else + llvm_unreachable("unexpected dimensions used in broadcast"); + } + return success(); +} + +LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, + const Location &loc, OpBuilder &builder) { + assert(this->isEmpty()); + + auto src = splatOp.getSrc(); + auto dst = splatOp.getResult(); + auto dstShape = cast(dst.getType()).getShape(); + + if (!isa(src.getType())) { + InFlightDiagnostic diag = + emitError(loc) + << "splat source must be an integer scalar for load/store masks"; + return failure(); + } + + if (failed(this->parse(src, loc, builder))) + return failure(); + + auto splatAsMask = 
[&](Operation *userOp) -> bool { + return TypeSwitch(userOp) + .Case([&](arith::AndIOp andOp) { return true; }) + .Case([&](arith::SelectOp selectOp) { + return selectOp.getCondition() == dst; + }) + .Case( + [&](triton::LoadOp loadOp) { return loadOp.getMask() == dst; }) + .Case( + [&](triton::StoreOp storeOp) { return storeOp.getMask() == dst; }) + .Default([&](Operation *op) { return false; }); + }; + + if (src.getType().isInteger(1) && !splatOp->use_empty() && + llvm::all_of(splatOp->getUsers(), splatAsMask)) { + for (auto s : dstShape) { + auto currentDim = + mulOpFoldResult(builder.getIndexAttr(s), this->scalar, loc, builder); + this->dims.push_back(currentDim); + this->offsets.push_back(builder.getIndexAttr(0)); + } + + this->scalar = nullptr; + return success(); + } + + for (auto s : dstShape) { + this->dims.push_back(builder.getIndexAttr(s)); + this->offsets.push_back(builder.getIndexAttr(0)); + } + return success(); +} + +LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp, + const Location &loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + if (failed(this->parse(expandDimsOp.getSrc(), loc, builder))) { + return failure(); + } + + auto dstShape = + cast(expandDimsOp.getResult().getType()).getShape(); + auto axis = expandDimsOp.getAxis(); + assert(dstShape[axis] == 1 && + "Expect changed dimention to be 1 in expand_dims"); + this->dims.insert(this->dims.begin() + axis, builder.getIndexAttr(1)); + this->offsets.insert(this->offsets.begin() + axis, builder.getIndexAttr(0)); + + return success(); +} + +void MaskState::eraseInsertedOps(Operation *rawOp, PatternRewriter &rewriter) { + auto moduleOp = rawOp->getParentOfType(); + SmallVector worklist; + moduleOp->walk([&](Operation *op) { + if (isOpTriviallyDead(op)) + worklist.push_back(op); + }); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + if (!isOpTriviallyDead(op)) + continue; + for (Value value : op->getOperands()) { + if (auto defOp = value.getDefiningOp()) + worklist.push_back(defOp); + } + LLVM_DEBUG({ + llvm::dbgs() << "[MaskState]==> inserted op: \n" + << *op << "\n[MaskState]<== is removed\n"; + }); + rewriter.eraseOp(op); + } +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp new file mode 100644 index 000000000..6dbf3ac75 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp @@ -0,0 +1,1715 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation, Meta Platforms. +// Licensed under the MIT license. 
+// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/TritonOpConverter.h" +#include "TritonToLinalg/TritonToLinalgPass.h" +#include "TritonToLinalg/BlockPtrAnalysis.h" +#include "TritonToLinalg/MaskAnalysis.h" +#include "Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/ValueRange.h" + +namespace TTOpConverters { +using namespace mlir; +using namespace triton; + +LogicalResult +BitcastConverter::matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value result; + if (auto resPointerType = dyn_cast(op.getType())) { + // TODO: use typeconverter + auto srcPointerType = cast(op.getSrc().getType()); + auto resType = MemRefType::get({ShapedType::kDynamic}, resPointerType.getPointeeType()); + // Handling special case + // %0 = tt.bitcast %arg0 {MixUse} : !tt.ptr -> !tt.ptr + if (isa(adaptor.getSrc()) && + srcPointerType.getPointeeType() == rewriter.getIntegerType(1) && + resPointerType.getPointeeType() == rewriter.getIntegerType(8)) { + rewriter.modifyOpInPlace(op, [&]() { + op->setAttr("MetaUse", rewriter.getUnitAttr()); + }); + return success(); + } + result = rewriter.create( + op.getLoc(), resType, adaptor.getSrc()); + } else { + result = rewriter.create( + op.getLoc(), op.getType(), adaptor.getSrc()); + } + rewriter.replaceOp(op, result); + return success(); +} + +LogicalResult +TransposeConverter::matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto res = ConverterUtils::getTransposedValue(src, op.getLoc(), rewriter, + op.getOrder()); + rewriter.replaceOp(op, res); + return success(); +} + +LogicalResult +YieldConverter::matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); +} + +LogicalResult +AdvanceConverter::matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm::SmallDenseMap known; + BlockDataParser::rewriteAdvanceOp(op, rewriter, known); + return success(); +} + +// ToDo: +// 1. Refactor MakeTensorPtrConverter and AdvanceConverter with +// memref::ReinterpretCastOp and memref::SubViewOp. +// Use recast to describe full shape of tensor, and use subview to represent +// current block tensor. 
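+// A possible shape of that refactor (sketch only, names illustrative):
+//   %full = memref.reinterpret_cast %base to offset: [%off],
+//           sizes: [%M, %N], strides: [%sM, %sN]
+//   %blk  = memref.subview %full[%i, %j] [BLOCK_M, BLOCK_N] [1, 1]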
+LogicalResult MakeTensorPtrConverter::matchAndRewrite( + triton::MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm::SmallDenseMap known; + BlockDataParser::rewriteMakeTensorPtrOp(op, adaptor.getBase(), rewriter, known); + return success(); +} + +LogicalResult PreciseDivConverter::matchAndRewrite( + triton::PreciseDivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value opa = op.getX(); + Value opb = op.getY(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + auto divOp = rewriter.create(loc, resType, opa, opb); + + rewriter.replaceOp(op, divOp); + return success(); +} + +/* + * Move tt.bitcast to a previous location if tt.bitcast is not directly applied + * on function arguments + */ +LogicalResult +BitcastCanonicalizer::matchAndRewrite(triton::BitcastOp bitcastOp, + PatternRewriter &rewriter) const { + Value castSrc = bitcastOp.getSrc(); + Value castRes = bitcastOp.getResult(); + Type castSrcTy = castSrc.getType(); + Type castSrcPtrTy = isa(castSrcTy) + ? cast(castSrcTy).getElementType() + : castSrcTy; + if (!isa(castSrcPtrTy)) + return failure(); + + auto origBitwidth = getPointeeBitWidth(castSrc.getType()); + auto castBitwidth = getPointeeBitWidth(castRes.getType()); + + if (origBitwidth == 1) + origBitwidth = 8; + if (castBitwidth == 1) + castBitwidth = 8; + if (origBitwidth != castBitwidth) { + bitcastOp.emitError() << "Casting pointers with unmatched bitwidth!\n"; + return failure(); + } + + Operation *beforeCastOp = castSrc.getDefiningOp(); + if (beforeCastOp == nullptr) { + return failure(); + } + + auto newRes = + TypeSwitch>(beforeCastOp) + // before: addptr - bitcast - load/store + // after: bitcast - addptr - load/store + .Case([&](triton::AddPtrOp addptrOp) { + auto newCastOp = rewriter.create( + bitcastOp.getLoc(), castRes.getType(), addptrOp.getPtr()); + return rewriter.create( + bitcastOp.getLoc(), castRes.getType(), newCastOp.getResult(), + addptrOp.getOffset()); + }) + .Case([&](triton::SplatOp splatOp) { + Type newCastSrcTy = + cast(castRes.getType()).getElementType(); + + Value splatSrc = splatOp.getSrc(); + Type splatSrcTy = splatSrc.getType(); + if (auto splatSrcTensorTy = dyn_cast(splatSrcTy)) + newCastSrcTy = + splatSrcTensorTy.cloneWith(std::nullopt, newCastSrcTy); + auto newCastOp = rewriter.create( + bitcastOp.getLoc(), newCastSrcTy, splatSrc); + return rewriter.create( + bitcastOp.getLoc(), castRes.getType(), newCastOp); + }) + // before: bitcast - bitcast + // after(fusion optimization): bitcast + .Case([&](triton::BitcastOp prevCastOp) { + return rewriter.create( + bitcastOp.getLoc(), castRes.getType(), prevCastOp.getSrc()); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(bitcastOp, + "Unknown bitcast pattern"); + }); + if (succeeded(newRes)) { + rewriter.replaceOp(bitcastOp, newRes.value()); + if (beforeCastOp->use_empty()) { + rewriter.eraseOp(beforeCastOp); + } + return success(); + } + return failure(); +} + +void rewriteUserWithNewOrder(mlir::OpOperand *use, PatternRewriter &rewriter, llvm::SmallVector &blkShapeI64, // 8: container size + mlir::Location &loc, llvm::ArrayRef &order, size_t &orderSize) +{ + Operation *user = use->getOwner(); + rewriter.setInsertionPointAfter(user); + if (auto loadOp = dyn_cast(user)) { + auto loadResTy = loadOp.getResult().getType(); + auto loadResShapedTy = cast(loadResTy); + auto newLoadTy = loadResShapedTy.cloneWith( + blkShapeI64, loadResShapedTy.getElementType()); + auto newLoadOp = 
rewriter.create( + loc, newLoadTy, loadOp->getOperands(), loadOp->getAttrs()); + newLoadOp->setAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG, UnitAttr::get(rewriter.getContext())); + rewriter.replaceOp(loadOp, newLoadOp); + // load contiguous data then permute. thus the permute order is as + // follows. + SmallVector permuteOrder; // 8: container size + for (auto [i, v] : llvm::enumerate(order)) { + permuteOrder.push_back(orderSize - 1 - order[i]); + } + auto permuteOp = rewriter.create( + loc, newLoadOp.getResult(), + DenseI32ArrayAttr::get(loadOp.getContext(), permuteOrder)); + newLoadOp.getResult().replaceAllUsesExcept(permuteOp.getResult(), permuteOp); + } else if (auto storeOp = dyn_cast(user)) { + // permute to contiguous then store. thus the permute order is as follows. + SmallVector permuteOrder; // 8: container size + for (auto [i, v] : llvm::enumerate(order)) { + permuteOrder.push_back(order[orderSize - 1 - i]); + } + auto permuteOp = rewriter.create( + loc, storeOp.getValue(), + DenseI32ArrayAttr::get(storeOp.getContext(), permuteOrder)); + storeOp.getValue().replaceAllUsesExcept(permuteOp.getResult(), permuteOp); + auto newStoreOp = rewriter.create( + loc, storeOp.getPtr(), storeOp.getValue(), storeOp.getMask(), + storeOp.getBoundaryCheck(), storeOp.getCache(), storeOp.getEvict()); + rewriter.replaceOp(storeOp, newStoreOp); + } else if (auto advanceOp = dyn_cast(user)) { + auto advanceResPtrTy = + cast(advanceOp.getResult().getType()); + auto advanceResShapedTy = + cast(advanceResPtrTy.getPointeeType()); + auto newAdvanceResShapedTy = advanceResShapedTy.cloneWith( + blkShapeI64, advanceResShapedTy.getElementType()); + auto newAdvanceResPtrTy = triton::PointerType::get( + newAdvanceResShapedTy, advanceResPtrTy.getAddressSpace()); + auto advanceOffsets = advanceOp.getOffsets(); + llvm::SmallVector newAdvanceOffsets; // 8: container size + for (int i = orderSize - 1; i >= 0; i--) { + newAdvanceOffsets.push_back(advanceOffsets[order[i]]); + } + SmallVector resUses; + for (auto &use: advanceOp->getUses()) + resUses.push_back(&use); + auto newAdvanceOp = rewriter.create( + loc, newAdvanceResPtrTy, advanceOp.getPtr(), newAdvanceOffsets); + rewriter.replaceOp(advanceOp, newAdvanceOp); + for (auto resUse : resUses) + rewriteUserWithNewOrder(resUse, rewriter, blkShapeI64, loc, order, orderSize); + } else if (auto loopOp = dyn_cast(user)) { + auto initArg = use->get(); + auto iterArg = loopOp.getTiedLoopRegionIterArg(use); + auto resultValue = loopOp.getTiedLoopResult(use); + iterArg.setType(initArg.getType()); + resultValue.setType(initArg.getType()); + for (auto &argUse : iterArg.getUses()) + rewriteUserWithNewOrder(&argUse, rewriter, blkShapeI64, loc, order, orderSize); + for (auto &resUse : resultValue.getUses()) + rewriteUserWithNewOrder(&resUse, rewriter, blkShapeI64, loc, order, orderSize); + } else if (isa(user)) { + return; + } else { + llvm_unreachable("[MakeTensorPtrCanonicalizer] tt.make_tensor_ptr's result is " + "not used by load/store/advance op"); + } +} + +void markLoadUsers(mlir::OpOperand *use, PatternRewriter &rewriter) +{ + Operation *user = use->getOwner(); + if (auto loadOp = dyn_cast(user)) { + loadOp->setAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG, UnitAttr::get(rewriter.getContext())); + } else if (auto storeOp = dyn_cast(user)) { + return; + } else if (auto advanceOp = dyn_cast(user)) { + SmallVector resUses; + for (auto &use: advanceOp->getUses()) + resUses.push_back(&use); + for (auto resUse : resUses) + markLoadUsers(resUse, rewriter); + } else if 
(auto loopOp = dyn_cast(user)) { + auto initArg = use->get(); + auto iterArg = loopOp.getTiedLoopRegionIterArg(use); + auto resultValue = loopOp.getTiedLoopResult(use); + iterArg.setType(initArg.getType()); + resultValue.setType(initArg.getType()); + for (auto &argUse : iterArg.getUses()) + markLoadUsers(&argUse, rewriter); + for (auto &resUse : resultValue.getUses()) + markLoadUsers(&resUse, rewriter); + } else if (isa(user)) { + return; + } else { + llvm_unreachable("[MakeTensorPtrCanonicalizer] tt.make_tensor_ptr's result is " + "not used by load/store/advance op"); + } +} + +LogicalResult +MakeTensorPtrCanonicalizer::matchAndRewrite(triton::MakeTensorPtrOp op, + PatternRewriter &rewriter) const { + auto order = op.getOrder(); + auto orderSize = order.size(); + if (orderSize == 1) { + return rewriter.notifyMatchFailure( + op, "make_tensor_ptr's order has single value."); + } + + bool isPermuted = false; + for (auto [first, second] : llvm::zip(order.slice(0, orderSize - 1), + order.slice(1, orderSize - 1))) { + if (first != second + 1) { + isPermuted = true; + break; + } + } + + auto loc = op.getLoc(); + auto base = op.getBase(); + auto shape = op.getShape(); + auto strides = op.getStrides(); + auto offsets = op.getOffsets(); + auto result = op.getResult(); + SmallVector opUses; + + for (auto &use: result.getUses()) + opUses.push_back(&use); + for (auto use : opUses) + markLoadUsers(use, rewriter); + + if (!isPermuted) { + return rewriter.notifyMatchFailure( + op, "make_tensor_ptr's order is contiguous."); + } + + llvm::SmallVector blkShapeI32; + llvm::SmallVector blkShapeI64; + auto resPtrType = cast(result.getType()); + if (auto resShapedTy = dyn_cast(resPtrType.getPointeeType())) { + auto resBlkShape = resShapedTy.getShape(); + for (auto [i, v] : llvm::enumerate(resBlkShape)) { + auto reverseI = orderSize - 1 - i; + blkShapeI32.push_back(resBlkShape[order[reverseI]]); + blkShapeI64.push_back(resBlkShape[order[reverseI]]); + } + } + + llvm::SmallVector newShape; + llvm::SmallVector newStrides; + llvm::SmallVector newOffsets; + for (int i = orderSize - 1; i >= 0; i--) { + newShape.push_back(shape[order[i]]); + newStrides.push_back(strides[order[i]]); + newOffsets.push_back(offsets[order[i]]); + } + + llvm::SmallVector contiguousOrder; + for (int i = orderSize - 1; i >= 0; i--) + contiguousOrder.push_back(i); + + rewriter.setInsertionPoint(op); + auto newMakeTensorPtrOp = rewriter.create( + loc, base, ValueRange(newShape), ValueRange(newStrides), + ValueRange(newOffsets), blkShapeI32, contiguousOrder); + rewriter.replaceOp(op, newMakeTensorPtrOp); + for (auto use : opUses) + rewriteUserWithNewOrder(use, rewriter, blkShapeI64, loc, order, orderSize); + return success(); +} + +LogicalResult ReduceSingleCanonicalizer::matchAndRewrite(triton::ReduceOp reduceOp, PatternRewriter &rewriter) const +{ + auto srcs = reduceOp.getSrcs(); + bool allSrcSingleElem = true; + for (auto src : srcs) { + auto srcType = cast(src.getType()); + auto srcShape = srcType.getShape(); + int64_t numel = 1; + for (auto s : srcShape) { + numel *= s; + } + if (numel != 1) { + allSrcSingleElem = false; + break; + } + } + + if (!allSrcSingleElem) { + return rewriter.notifyMatchFailure(reduceOp, "reduce's srcs are not all with single element"); + } + + auto results = reduceOp.getResult(); + auto loc = reduceOp->getLoc(); + auto zero = rewriter + .create(loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), 0)) + .getResult(); + for (int i = 0; i < srcs.size(); i++) { + auto src = srcs[i]; + auto 
srcType = cast(src.getType()); + auto srcRank = srcType.getRank(); + auto res = results[i]; + Value extracted; + if (srcRank == 1) { + // vector reduce generates a scalar result + extracted = rewriter.create(loc, src, zero).getResult(); + } else { + auto srcShape = srcType.getShape(); + auto resType = cast(res.getType()); + auto resShape = resType.getShape(); + auto collapseReassociationIndicesOptional = getReassociationIndicesForCollapse(srcShape, resShape); + if (!collapseReassociationIndicesOptional.has_value()) { + return rewriter.notifyMatchFailure(reduceOp, "Failure with getReassociationIndicesForCollapse call"); + } + auto collapseReassociationIndices = collapseReassociationIndicesOptional.value(); + extracted = rewriter.create(loc, src, collapseReassociationIndices).getResult(); + } + res.replaceAllUsesWith(extracted); + } + + return success(); +} + +LogicalResult DenseConstantConverter::matchAndRewrite( + arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto denseAttr = cast(op.getValue()); + auto loc = op.getLoc(); + auto constSplatOp = arith::ConstantOp::materialize( + rewriter, denseAttr.getSplatValue(), + denseAttr.getElementType(), loc); + auto emptyOp = rewriter.create( + loc, cast(op.getResult().getType()).getShape(), + denseAttr.getElementType()); + + rewriter.replaceOpWithNewOp(op, ValueRange{constSplatOp}, + ValueRange{emptyOp}); + + return success(); +} + +LogicalResult +MakeRangeConverter::matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto type = cast(op.getResult().getType()); + auto shape = type.getShape(); + auto elementType = type.getElementType(); + auto context = op.getContext(); + + assert(type.getShape().size() == 1 && + isa(type.getElementType()) && + type.getElementType().getIntOrFloatBitWidth() == 32 && + "make range can only return 1D int32 tensor"); + + SmallVector indexingMaps{AffineMap::get( + /* dimCount */ 1, /* symbolCount */ 0, + {mlir::getAffineDimExpr(0, context)}, context)}; + + auto init = rewriter.create(loc, shape, elementType); + + auto nestedBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value index = nestedBuilder.create(loc, 0); + Value res = nestedBuilder.create( + loc, elementType, index); + nestedBuilder.create(loc, res); + }; + + auto linalgOp = rewriter.create( + loc, op->getResultTypes(), /* operands */ ValueRange{}, ValueRange{init}, + indexingMaps, ConverterUtils::getNParallelLoopsAttrs(1), nestedBody); + + int32_t startVal = op.getStartAttr().getInt(); + if (startVal == 0) { + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } + + // Apply start offset + Value startScaler = rewriter.create( + loc, rewriter.getI32IntegerAttr(static_cast(startVal))); + auto startInit = rewriter.create(loc, shape, elementType); + Value startTensor = rewriter.create( + loc, ValueRange{startScaler}, ValueRange{startInit}).getResult(0); + auto addOp = rewriter.create(loc, linalgOp->getResult(0), + startTensor); + rewriter.replaceOp(op, addOp); + return success(); +} + +LogicalResult +SplatConverter::matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto init = rewriter.create(loc, op.getType().getShape(), + op.getType().getElementType()); + rewriter.replaceOpWithNewOp(op, ValueRange{adaptor.getSrc()}, + ValueRange{init}); + return success(); +} + +LogicalResult 
+ReshapeConverter::matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto src = op.getSrc(); + auto dst = op.getResult(); + Value shape = rewriter.create( + loc, + rewriter.getI64TensorAttr(cast(dst.getType()).getShape())); + auto reshapeOp = + rewriter.create(loc, dst.getType(), src, shape); + rewriter.replaceOp(op, reshapeOp.getResult()); + return success(); +} + +LogicalResult ExpandDimsConverter::matchAndRewrite( + triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto src = op.getSrc(); + auto resShape = cast(op.getResult().getType()).getShape(); + auto axis = op.getAxis(); + + SmallVector reassociation; + + auto src_last_dim = resShape.size() - 2; + auto map_func = [&](unsigned i) -> ReassociationIndices { + if (i < axis) { + return i == src_last_dim ? ReassociationIndices{i, i + 1} + : ReassociationIndices{i}; + } + return i == axis ? ReassociationIndices{i, i + 1} + : ReassociationIndices{i + 1}; + }; + + reassociation = llvm::to_vector( + llvm::map_range(llvm::seq(0, src_last_dim + 1), map_func)); + + auto expandShapeOp = rewriter.create( + op.getLoc(), op.getResult().getType(), src, reassociation); + rewriter.replaceOp(op, expandShapeOp.getResult()); + return success(); +} + +LogicalResult +ClampFConverter::matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto input = adaptor.getX(); + auto min_para = adaptor.getMin(); + auto max_para = adaptor.getMax(); + auto propagateNan_para = adaptor.getPropagateNan(); + + if (auto input_type = dyn_cast(input.getType())) { + if (isa(min_para.getType())) { + auto minEmptyTensor = rewriter.create( + loc, input_type.getShape(), input_type.getElementType()); + min_para = rewriter + .create(loc, ValueRange{min_para}, + ValueRange{minEmptyTensor}) + .result(); + } + if (isa(max_para.getType())) { + auto maxEmptyTensor = rewriter.create( + loc, input_type.getShape(), input_type.getElementType()); + max_para = rewriter + .create(loc, ValueRange{max_para}, + ValueRange{maxEmptyTensor}) + .result(); + } + } + + if (propagateNan_para == PropagateNan::NONE) { + auto minOp = rewriter.create(loc, input, max_para); + auto maxOp = rewriter.create(loc, min_para, minOp); + rewriter.replaceOp(op, ValueRange{maxOp}); + } else if (propagateNan_para == PropagateNan::ALL) { + auto minOp = rewriter.create(loc, input, max_para); + auto maxOp = rewriter.create(loc, min_para, minOp); + rewriter.replaceOp(op, ValueRange{maxOp}); + } else { + return failure(); + } + + return success(); +} + +// Here convert tt.broadcast to linalg.broadcast +// +// before +// %out = tt.broadcast %in : tensor<1x4x8xf32> -> tensor<128x4x8xf32> +// +// after +// %collpased = tensor.collapse_shape %in [[0, 1], [2]] : +// tensor<1x4x8xf32> into tensor<4x8xf32> +// %out = linalg.broadcast ins(%collpased : tensor<4x8xf32>) +// outs(%empty : tensor<128x4x8xf32>) dimensions = [0] +LogicalResult +BroadcastConverter::matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(op->getNumResults() == 1 && "BroadcastOp assumes single result"); + + RankedTensorType sourceType = + cast(adaptor.getSrc().getType()); + RankedTensorType resultType = cast(op.getType()); + auto elementType = resultType.getElementType(); + auto loc = op.getLoc(); + + auto initEmpty = + rewriter.create(loc, resultType.getShape(), elementType); + 
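+  // For the 1x4x8 -> 128x4x8 example above this yields broadcastDims = [0]
+  // and unbroadcastDims = [4, 8]: the dimensions to materialize and the
+  // shape that survives the collapse, respectively.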
+ SmallVector broadcastDims = + ConverterUtils::getBroadcastDims(sourceType, resultType); + SmallVector unbroadcastDims = + ConverterUtils::getUnbroadcastDims(sourceType, resultType); + + SmallVector collapseReassociationIndices; + auto collapseReassociationIndicesOptional = + getReassociationIndicesForCollapse(sourceType.getShape(), + unbroadcastDims); + if (!collapseReassociationIndicesOptional.has_value()) { + return rewriter.notifyMatchFailure( + op, "Failure with getReassociationIndicesForCollapse call"); + } + collapseReassociationIndices = collapseReassociationIndicesOptional.value(); + + RankedTensorType collapseResultType = + RankedTensorType::get(unbroadcastDims, sourceType.getElementType()); + + auto collpasedOp = rewriter.create( + loc, collapseResultType, adaptor.getSrc(), collapseReassociationIndices); + + auto broadcastOp = rewriter.create( + loc, collpasedOp, initEmpty, + rewriter.getDenseI64ArrayAttr(broadcastDims)); + + rewriter.replaceOp(op, broadcastOp.getResults()); + return success(); +} + +// Reduce Converter +bool ReduceConverter::isReductionOpSupported(Operation *redOp) const { + return isa(redOp); +} + +LogicalResult ReduceConverter::convertToTargetOp( + triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto source = adaptor.getOperands().front(); + auto sourceType = cast(source.getType()); + auto elemType = sourceType.getElementType(); + auto resType = op.getResult().front().getType(); + auto loc = op.getLoc(); + auto reductionOps = this->getRedOps(op); + + // Reduction of arbitrary operations isn't supported because using the first + // element across the reduction dimension requires us to iterate over a + // subview that skips over each first element. + if (!this->isReductionOpSupported(reductionOps.front())) { + return rewriter.notifyMatchFailure( + op, "Only support lowering reduction with single op and limited types of reducetion"); + } + + auto rop = reductionOps.front(); + auto axis = op.getAxis(); + auto isVectorReduce = sourceType.getRank() == 1; + + auto constantType = elemType; + + auto accBaseConstOp = this->getRedBaseConstOp(rewriter, rop, constantType); + Value initTensor; + + if (isVectorReduce) { + auto holder = rewriter.create( + loc, RankedTensorType::get({}, constantType), ValueRange{}); + initTensor = rewriter + .create(loc, accBaseConstOp.getResult(), + holder.getResult()) + .getResult(0); + } else { + Value init = rewriter.create( + loc, cast(resType).getShape(), constantType); + initTensor = + rewriter.create(loc, accBaseConstOp.getResult(), init) + .getResult(0); + } + + Value finalResult = rewriter.create( + loc, ValueRange{source}, ValueRange{initTensor}, + SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + assert(inputs.size() == 2); + Value result = this->getRedElement(inputs[0], inputs[1], loc, rop, + opBuilder, false); + opBuilder.create(loc, result); + }) + .getResult(0); + + if (sourceType.getRank() == 1) { + finalResult = rewriter.create(loc, constantType, finalResult); + } + + rewriter.replaceOp(op, finalResult); + return success(); +} + +LogicalResult ReduceConverter::convertToTargetOpExtended( + triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto elemTypes = op.getElementTypes(); + + auto valueResultType = dyn_cast(op.getType(0)); + const auto isScalarReduce = valueResultType == nullptr; + + SmallVector outputs; + for (auto i = 0; i < 
op.getResult().size() && i < elemTypes.size(); i++) { + auto result = dyn_cast(op.getType(i)); + SmallVector resultShape{ + isScalarReduce ? SmallVector{} + : SmallVector(result.getShape())}; + outputs.push_back( + rewriter.create(loc, resultShape, elemTypes[i])); + } + + auto linalgOp = rewriter.create( + loc, adaptor.getOperands(), outputs, + SmallVector{adaptor.getAxis()}, + [&](OpBuilder &b, Location loc, ValueRange inputs) { + auto tritonReduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(tritonReduceBlock->getArguments(), inputs); + + for (auto &op : tritonReduceBlock->without_terminator()) { + b.clone(op, mapping); + } + + auto tritonYield = tritonReduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(tritonYield->getOperands(), + [&](Value val) { return mapping.lookup(val); }); + b.create(loc, results); + }); + + if (failed(addReduceWithIndexAttrIfNeeded(rewriter, linalgOp))) { + return rewriter.notifyMatchFailure(op, "meaningless reduce operation"); + } + + if (isScalarReduce) { + SmallVector reduceResults; + for (auto i = 0; i < linalgOp.getResults().size() && i < elemTypes.size(); + i++) { + reduceResults.push_back(rewriter.create( + loc, elemTypes[i], linalgOp.getResults()[i], ValueRange{})); + } + rewriter.replaceOp(op, reduceResults); + } else { + rewriter.replaceOp(op, linalgOp); + } + return success(); +} + +bool ScanConverter::isReductionOpSupported(Operation *redOp) const { + return isa(redOp); +} + +LogicalResult ScanConverter::convertToTargetOp( + triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto reductionOps = this->getRedOps(op); + if (reductionOps.empty()) { + return rewriter.notifyMatchFailure(op, "No reduction op found in scan body"); + } + + bool reverse = op.getReverse(); + if (reverse) { + op.emitError("reverse=True is not yet supported for scan op"); + return failure(); + } + + llvm::SmallString<64> funcName; + auto rop = reductionOps.front(); + if (this->isReductionOpSupported(reductionOps.front())) { + if (isa(rop)) { + funcName = "triton_cumsum"; + } else if (isa(rop)) { + funcName = "triton_cumprod"; + } + + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPoint(moduleOp.getBody(), + std::prev(moduleOp.getBody()->end())); + + auto loc = op.getLoc(); + auto src = adaptor.getOperands().front(); + auto resTy = op.getResult().front().getType(); + auto libFnType = rewriter.getFunctionType( + {src.getType(), rewriter.getI32Type(), rewriter.getI1Type()}, {resTy}); + auto funcOp = rewriter.create(loc, funcName.str(), libFnType); + + SymbolTable symTab(moduleOp); + auto maybePrintFuncNameAttr = symTab.renameToUnique(funcOp, {&symTab}); + if (failed(maybePrintFuncNameAttr)) { + return op->emitError( + "failed to create a unique func name for device_print"); + } + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + + rewriter.setInsertionPoint(op); + auto scanAxis = op.getAxis(); + auto scanReverse = op.getReverse(); + Value axis = rewriter.create(loc, scanAxis, 32); + Value reverseVal = rewriter.create(loc, scanReverse, 1); + auto callOp = rewriter.create(loc, funcOp.getSymNameAttr(), + TypeRange({resTy}), + ValueRange({src, axis, reverseVal})); + + rewriter.replaceOp(op, callOp); + + return success(); + } else { + // This branch is the associative_scan op. 
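+    // Lowering sketch: bufferize the input, seed output[0] along the scan
+    // axis, then for k in [1, axisSize) clone the combine region on
+    // (output[k-1], input[k]) and store the result to output[k]; the nested
+    // loops built below cover every non-scan dimension.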
+ auto loc = op.getLoc(); + + Value scanInput = op.getOperand(0); + + scanInput.dump(); + + for (Value operand : op->getOperands()) { + operand.dump(); + } + + auto srcType = mlir::dyn_cast(scanInput.getType()); + if (!srcType) { + return rewriter.notifyMatchFailure(op, "Expected RankedTensorType input for associative_scan"); + } + + auto elementType = srcType.getElementType(); + auto shape = srcType.getShape(); + int rank = shape.size(); + int axis = op.getAxis(); + + if (axis < 0 || axis >= rank) { + return rewriter.notifyMatchFailure(op, "Invalid scan axis: " + std::to_string(axis)); + } + + if (op->getNumRegions() < 1 || op->getRegion(0).empty()) { + return rewriter.notifyMatchFailure(op, "Missing combine region"); + } + + OpBuilder::InsertionGuard guard(rewriter); + + auto memrefType = MemRefType::get(shape, elementType); + Value inputMemRef = rewriter.create(loc, memrefType, scanInput); + Value outputMemRef = rewriter.create(loc, memrefType); + + auto processDimension = [&](ArrayRef baseIdxsArray) { + llvm::SmallVector baseIdxs(baseIdxsArray.begin(), baseIdxsArray.end()); + llvm::SmallVector firstIdx = baseIdxs; + if (axis <= firstIdx.size()) { + firstIdx.insert(firstIdx.begin() + axis, + rewriter.create(loc, 0)); + } else { + firstIdx.push_back(rewriter.create(loc, 0)); + } + + Value firstVal = rewriter.create(loc, inputMemRef, firstIdx); + rewriter.create(loc, firstVal, outputMemRef, firstIdx); + + Value axisSize = rewriter.create(loc, inputMemRef, axis).getResult(); + Value one = rewriter.create(loc, 1); + + Value cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, axisSize, one); + auto ifOp = rewriter.create(loc, cmp, false); + + // Create a loop only when the axis size is greater than 1. + rewriter.setInsertionPointToStart(ifOp.thenBlock()); + + auto forOp = rewriter.create(loc, one, axisSize, one); + rewriter.setInsertionPointToStart(forOp.getBody()); + + Value k = forOp.getInductionVar(); + llvm::SmallVector currIdx = baseIdxs; + if (axis <= currIdx.size()) { + currIdx.insert(currIdx.begin() + axis, k); + } else { + currIdx.push_back(k); + } + + Value km1 = rewriter.create(loc, k, one); + llvm::SmallVector prevIdx = baseIdxs; + if (axis <= prevIdx.size()) { + prevIdx.insert(prevIdx.begin() + axis, km1); + } else { + prevIdx.push_back(km1); + } + + Value currentVal = rewriter.create(loc, inputMemRef, currIdx); + Value prevResult = rewriter.create(loc, outputMemRef, prevIdx); + + Region &combineRegion = op->getRegion(0); + Block &combineBlock = combineRegion.front(); + IRMapping mapping; + mapping.map(combineBlock.getArgument(0), prevResult); + mapping.map(combineBlock.getArgument(1), currentVal); + + for (Operation &innerOp : combineBlock.without_terminator()) { + rewriter.clone(innerOp, mapping); + } + + Operation *yieldOp = combineBlock.getTerminator(); + Value resultVal = mapping.lookup(yieldOp->getOperand(0)); + + rewriter.create(loc, resultVal, outputMemRef, currIdx); + + rewriter.setInsertionPointAfter(ifOp); + }; + + // Constructing loops for non-scanning dimensions + llvm::SmallVector nonScanDims; + for (int i = 0; i < rank; ++i) { + if (i != axis) nonScanDims.push_back(i); + } + + createSimpleNestedLoops(rewriter, loc, outputMemRef, nonScanDims, processDimension); + + rewriter.setInsertionPointAfter(op); + + Value outputTensor = rewriter.create(loc, outputMemRef, true); + rewriter.replaceOp(op, outputTensor); + return success(); + } +} + +LogicalResult ScanConverter::convertToTargetOpExtended( + triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor, + 
ConversionPatternRewriter &rewriter) const { + return op->emitError("tt.scan with multiple ops inside the body unsupported!"); +} + +LogicalResult ExternElementwiseClOpConverter::matchAndRewrite( + triton::ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + if (!op.getPure()) { + op->emitWarning() << "impure elementwise op!"; + return failure(); + } + if (op.getSymbol().contains("__hmf_")) { + // 1. get or create the declaration of external elementwise function + Type dstTy = op.getResult().getType(); + bool isDstScalar = !isa(dstTy); + Type dstElemTy = + isDstScalar ? dstTy : cast(dstTy).getElementType(); + SmallVector srcElemTys; + SmallVector srcs; + for (auto src : op.getSrcs()) { + if (!isa(src.getType())) { + src = rewriter.create( + op.getLoc(), RankedTensorType::get({(int64_t)1}, src.getType()), + src); + } + srcs.push_back(src); + srcElemTys.push_back( + cast(src.getType()).getElementType()); + } + FunctionType elemFuncType = + FunctionType::get(rewriter.getContext(), srcElemTys, {dstElemTy}); + auto mod = SymbolTable::getNearestSymbolTable(op); + auto extFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(mod, op.getSymbol())); + if (!extFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&mod->getRegion(0).front()); + extFunc = rewriter.create(rewriter.getUnknownLoc(), + op.getSymbol(), elemFuncType); + extFunc.setPrivate(); + extFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), + UnitAttr::get(rewriter.getContext())); + } + assert(isa( + SymbolTable::lookupSymbolIn(mod, op.getSymbol()))); + // 2. prepare the output tensor + Value output; + if (isDstScalar) { + dstTy = RankedTensorType::get({(int64_t)1}, dstElemTy); + } + bool found = false; + for (Value v : srcs) { + if (v.getType() == dstTy) { + found = true; + output = v; + break; + } + } + if (!found) { + output = rewriter.create( + op.getLoc(), cast(dstTy).getShape(), dstElemTy); + } + // 3. 
create the linalg.map op + auto mapOp = rewriter.create( + loc, + /*inputs=*/srcs, + /*init=*/output, + /*bodyBuilder=*/ + [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { + auto elemOp = builder.create(loc, + /*name=*/op.getSymbol(), + /*resultType=*/dstElemTy, + /*operands=*/regionArgs); + builder.create(loc, elemOp->getResults()); + }); + if (isDstScalar) { + // need to convert tensor back to scalar + auto indexType = rewriter.getIndexType(); + Value zeroConstant = rewriter.create( + loc, indexType, rewriter.getIntegerAttr(indexType, 0)); + auto extractOp = rewriter.create( + loc, mapOp.getResults()[0], zeroConstant); + rewriter.replaceOp(op, extractOp); + } else { + rewriter.replaceOp(op, mapOp); + } + return success(); + } + return failure(); +} + +LogicalResult UnrealizedCastConverter::matchAndRewrite( + UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.eraseOp(op); + return success(); +} + +LogicalResult +JoinConverter::matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value opa = op.getLhs(); + Value opb = op.getRhs(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + Value emptyOp = rewriter.create(loc, resType.getShape(), + resType.getElementType()); + + auto shape = dyn_cast(opa.getType()).getShape(); + auto sizes = llvm::map_to_vector(shape, [&](int64_t t) { + return OpFoldResult(rewriter.getI64IntegerAttr(t)); + }); + sizes.push_back(rewriter.getI64IntegerAttr(1)); + + int64_t rank = resType.getRank(); + + // Set last dimension stride to 2 in layout + // As last dimension size is always 1, last dimension stride here could be + // either 1 or 2, while stride `2` could carry interleave trait and it's + // convenient for next lower. + SmallVector strides(rank, rewriter.getIndexAttr(1)); + strides.back() = rewriter.getIndexAttr(2); + + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + + auto insert0 = rewriter.create( + loc, opa, emptyOp, offsets, sizes, strides); + + offsets.back() = rewriter.getIndexAttr(1); + auto insert1 = rewriter.create( + loc, opb, insert0, offsets, sizes, strides); + rewriter.replaceOp(op, insert1); + return success(); +} + +LogicalResult +CatConverter::matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value opa = op.getLhs(); + Value opb = op.getRhs(); + auto loc = op.getLoc(); + + auto resType = dyn_cast(op.getResult().getType()); + auto emptyOp = rewriter.create(loc, resType.getShape(), + resType.getElementType()); + + auto rank = resType.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + + auto inputType = dyn_cast(opa.getType()); + + SmallVector sizes = + llvm::map_to_vector(inputType.getShape(), [&](int64_t t) { + return OpFoldResult(rewriter.getI64IntegerAttr(t)); + }); + + auto insert0 = rewriter.create( + loc, opa, emptyOp, offsets, sizes, strides); + + offsets[0] = + rewriter.getIndexAttr(inputType.getRank() ? inputType.getShape()[0] : 1); + auto insert1 = rewriter.create( + loc, opb, insert0, offsets, sizes, strides); + + rewriter.replaceOp(op, insert1); + return success(); +} + +/// @brief Convert tt.gather to func.call. BiShengIR captures the func +/// with assumed semantics. +/// @param op The `triton::GatherOp` operation to be rewritten. +/// @param adaptor An adaptor for the operation's operands. +/// @param rewriter A pattern rewriter used to modify the IR. 
+/// @return A `LogicalResult` indicating whether the rewrite was successful. +LogicalResult +GatherConverter::matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value src = adaptor.getSrc(); + Value idx = adaptor.getIndices(); + Value res = op.getResult(); + auto gatherAxis = op.getAxis(); + + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPoint(moduleOp.getBody(), + std::prev(moduleOp.getBody()->end())); + + llvm::SmallString<128> funcName = gatherFuncNameBase; + int uniqueId = 0; + while (SymbolTable::lookupSymbolIn(moduleOp, funcName)) { + funcName += "_" + std::to_string(uniqueId++); + } + + auto resTy = res.getType(); + auto libFnType = rewriter.getFunctionType( + {src.getType(), idx.getType(), rewriter.getI32Type()}, {resTy}); + auto funcOp = rewriter.create(loc, funcName.str(), libFnType); + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + + rewriter.setInsertionPoint(op); + Value axis = rewriter.create(loc, gatherAxis, 32); + auto callOp = rewriter.create(loc, funcOp.getSymNameAttr(), + TypeRange({resTy}), + ValueRange({src, idx, axis})); + + rewriter.replaceOp(op, callOp); + + return success(); +} + +LogicalResult +SplitConverter::matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = op.getSrc(); + auto loc = op.getLoc(); + auto inputType = cast(input.getType()); + + int64_t rank = inputType.getRank(); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + // Similar to JoinConverter, here adjust last dimension stride + SmallVector strides(rank, rewriter.getIndexAttr(1)); + strides.back() = rewriter.getIndexAttr(2); + + auto outType = dyn_cast(op.getOutLHS().getType()); + auto sizes = llvm::map_to_vector(outType.getShape(), [&](int64_t t) { + return OpFoldResult(rewriter.getIndexAttr(t)); + }); + sizes.push_back(rewriter.getIndexAttr(1)); + + auto slice0 = rewriter.create( + loc, outType, input, offsets, sizes, strides); + + offsets.back() = rewriter.getIndexAttr(1); + auto slice1 = rewriter.create( + loc, outType, input, offsets, sizes, strides); + + SmallVector slices = {slice0.getResult(), slice1.getResult()}; + rewriter.replaceOp(op, ValueRange(slices)); + return success(); +} + +/* +the element-wise most significant N bits of the 2N-bit product of x and y +%x:2 = arith.mulsi_extended %y, %z : tensor<4x?xi32> +*/ +LogicalResult TritonMulhiuiConverter::matchAndRewrite( + triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value opl = op.getX(); + Value opr = op.getY(); + Value res = op.getResult(); + auto newMulOp = rewriter.create( + loc, res.getType(), res.getType(), opl, opr); + // triton only need the high value + rewriter.replaceOp(op, ValueRange{newMulOp.getHigh()}); + return success(); +} + +LogicalResult TritonPreciseSqrtConverter::matchAndRewrite( + triton::PreciseSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); +} + +LogicalResult DevicePrintConverter::matchAndRewrite( + triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPoint(moduleOp.getBody(), + std::prev(moduleOp.getBody()->end())); + SmallVector inputTypes; + for (auto arg : op.getArgs()) { + inputTypes.push_back(arg.getType()); + } + auto libFnType = 
rewriter.getFunctionType(inputTypes, {}); + auto funcOp = + rewriter.create(op.getLoc(), printFuncNameBase, libFnType); + SymbolTable symTab(moduleOp); + auto maybePrintFuncNameAttr = symTab.renameToUnique(funcOp, {&symTab}); + if (failed(maybePrintFuncNameAttr)) { + return op->emitError( + "failed to create a unique func name for device_print"); + } + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + auto prefixAttr = op.getPrefixAttr(); + funcOp->setAttr(prefixAttrName, prefixAttr); + auto hexAttr = op.getHexAttr(); + funcOp->setAttr(hexAttrName, hexAttr); + + rewriter.setInsertionPoint(op); + rewriter.create(op.getLoc(), funcOp, op.getArgs()); + + rewriter.eraseOp(op); + return success(); +} + +LogicalResult DeviceAssertConverter::matchAndRewrite( + triton::AssertOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto msgAttr = op.getMessageAttr(); + // Filter out automatically inserted assert ops + if (auto strAttr = mlir::dyn_cast(msgAttr)) { + llvm::StringRef msg = strAttr.getValue(); + if (msg.contains("overflow detected for operation")) { + rewriter.eraseOp(op); + return success(); + } + } + + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPoint(moduleOp.getBody(), + std::prev(moduleOp.getBody()->end())); + auto conditionType = op.getCondition().getType(); + + auto libFnType = rewriter.getFunctionType({conditionType}, {}); + auto funcOp = + rewriter.create(op.getLoc(), printFuncNameBase, libFnType); + mlir::SymbolTable symTab(moduleOp); + auto maybePrintFuncNameAttr = symTab.renameToUnique(funcOp, {&symTab}); + if (failed(maybePrintFuncNameAttr)) { + return op->emitError( + "failed to create a unique func name for device_assert"); + } + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + funcOp->setAttr(msgAttrName, msgAttr); + + rewriter.setInsertionPoint(op); + rewriter.create(op.getLoc(), funcOp, ValueRange{op.getCondition()}); + + rewriter.eraseOp(op); + return success(); +} + +LogicalResult +MatmulConverter::matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto opa = adaptor.getA(); + auto opb = adaptor.getB(); + auto opc = adaptor.getC(); + auto dstType = cast(op.getType()); + auto inputPrec = op.getInputPrecision(); + + if (dstType.getRank() == 2) { + auto matmulOp = rewriter.replaceOpWithNewOp( + op, ValueRange{opa, opb}, ValueRange{opc}); + matmulOp->setAttr( + "input_precison", + rewriter.getStringAttr(stringifyInputPrecision(inputPrec))); + } else if (dstType.getRank() == 3) { + auto matmulOp = rewriter.replaceOpWithNewOp( + op, ValueRange{opa, opb}, ValueRange{opc}); + matmulOp->setAttr( + "input_precison", + rewriter.getStringAttr(stringifyInputPrecision(inputPrec))); + } else { + llvm_unreachable("Datatype of DotOp operands could only be 2D or 3D"); + } + return success(); +} + + +LogicalResult SortOpConverter::matchAndRewrite( + triton::SortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + { + Value src = adaptor.getSrc(); + auto rankedSrcTy = cast(src.getType()); + auto srcElemTy = rankedSrcTy.getElementType(); + auto srcShape = rankedSrcTy.getShape(); + auto srcEnc = rankedSrcTy.getEncoding(); + + MLIRContext *ctx = rewriter.getContext(); + + Type backendElemTy = srcElemTy; + if (srcElemTy.isInteger(8)) { + backendElemTy = Float16Type::get(ctx); // i8 -> f16 + } else if (srcElemTy.isInteger(16)) { + backendElemTy = Float32Type::get(ctx); // i16 -> f32 + } + Type backendTensorTy = 
RankedTensorType::get(srcShape, backendElemTy, srcEnc); + + Type valuesTy = src.getType(); + + Location loc = op.getLoc(); + auto dimAttr = op->getAttrOfType("dim"); + auto descAttr = op->getAttrOfType("descending"); + if (!dimAttr || !descAttr) { + op->emitError("missing 'dim' or 'descending' attribute"); + return failure(); + } + + auto moduleOp = op->getParentOfType(); + if (!moduleOp) { + op->emitError("must be inside a module"); + return failure(); + } + + llvm::SmallString<64> baseName("triton_sort"); + llvm::SmallString<64> funcName = baseName; + int uniqueId = 0; + while (SymbolTable::lookupSymbolIn(moduleOp, funcName)) { + funcName = baseName; + funcName += ("_" + std::to_string(uniqueId++)); + } + + auto i64Ty = IntegerType::get(ctx, 64); + auto i1Ty = IntegerType::get(ctx, 1); + auto libFnType = rewriter.getFunctionType( + {backendTensorTy, i64Ty, i1Ty}, + {backendTensorTy}); + + auto moduleIP = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToEnd(moduleOp.getBody()); + auto funcOp = rewriter.create(loc, funcName.str(), libFnType); + SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private); + rewriter.restoreInsertionPoint(moduleIP); + + Value srcForCall = src; + if (backendElemTy != srcElemTy) { + srcForCall = rewriter.create(loc, backendTensorTy, src); + } + + Value dimVal = rewriter.create(loc, dimAttr.getInt(), 64); + Value descVal = rewriter.create(loc, descAttr.getValue() ? 1 : 0, 1); + + auto callee = SymbolRefAttr::get(ctx, funcOp.getSymName()); + auto callOp = rewriter.create( + loc, + TypeRange({backendTensorTy}), + callee, + ValueRange({srcForCall, dimVal, descVal}) + ); + + Value valuesFloat = callOp.getResult(0); // tensor + + Value finalValues = valuesFloat; + if (backendElemTy != srcElemTy) { + finalValues = rewriter.create(loc, valuesTy, valuesFloat); + } + + rewriter.replaceOp(op, {finalValues}); + + return success(); +} + + +LogicalResult +DotScaledConverter::matchAndRewrite(triton::DotScaledOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const + { + Value lhs = adaptor.getLhs(); + Value lhsScale = adaptor.getLhsScale(); + Value rhsScale = adaptor.getRhsScale(); + Value rhs = adaptor.getRhs(); + Value c = adaptor.getC(); + RankedTensorType dstType = cast(op.getType()); + + RankedTensorType lhsTy = cast(lhs.getType()); + RankedTensorType lhsScaleTy = cast(lhsScale.getType()); + RankedTensorType rhsScaleTy = rhsScale ? 
cast(rhsScale.getType()) : nullptr; + RankedTensorType rhsTy = cast(rhs.getType()); + + Value lhsScaleOut; + Value rhsScaleOut; + Value c127 = rewriter.create( + op.getLoc(), + rewriter.getI16Type(), + rewriter.getI16IntegerAttr(127) + ); + Value c7 = rewriter.create( + op.getLoc(), + rewriter.getI16Type(), + rewriter.getI16IntegerAttr(7) + ); + Type i16Ty = rewriter.getI16Type(); + Type bf16Ty = rewriter.getBF16Type(); + Type fp16Ty = rewriter.getF16Type(); + Type fp32Ty = rewriter.getF32Type(); + + if (lhsScaleTy.getElementType().isIntOrIndex()) { + RankedTensorType lhsScaleI16Ty = RankedTensorType::get(lhsScaleTy.getShape(), i16Ty); + Value lhsScaleI16 = rewriter.create( + op.getLoc(), + lhsScaleI16Ty, + lhsScale + ); + + Value lhsShift127Empty = rewriter.create( + op.getLoc(), + lhsScaleI16Ty.getShape(), + i16Ty + ); + Value lhsShift127 = rewriter.create( + op.getLoc(), + ValueRange{c127}, + ValueRange{lhsShift127Empty} + ).getResult(0); + + Value lhsScaleI16Add127 = rewriter.create( + op.getLoc(), + lhsScaleI16, + lhsShift127 + ); + + Value lhsShift7Empty = rewriter.create( + op.getLoc(), + lhsScaleI16Ty.getShape(), + i16Ty + ); + Value lhsShift7 = rewriter.create( + op.getLoc(), + ValueRange{c7}, + ValueRange{lhsShift7Empty} + ).getResult(0); + Value lhsScaleI16Shifted = rewriter.create( + op.getLoc(), + lhsScaleI16Add127, + lhsShift7 + ); + + RankedTensorType lhsScaleBF16Ty = RankedTensorType::get(lhsScaleTy.getShape(), bf16Ty); + Value lhsScaleBF16 = rewriter.create( + op.getLoc(), + lhsScaleBF16Ty, + lhsScaleI16Shifted + ); + if (lhsTy.getElementType() == fp16Ty) { + RankedTensorType lhsScaleFp32Ty = RankedTensorType::get(lhsScaleTy.getShape(), fp32Ty); + Value lhsScaleFp32 = rewriter.create( + op.getLoc(), + lhsScaleFp32Ty, + lhsScaleBF16 + ); + RankedTensorType lhsScaleFp16Ty = RankedTensorType::get(lhsScaleTy.getShape(), fp16Ty); + lhsScaleOut = rewriter.create( + op.getLoc(), + lhsScaleFp16Ty, + lhsScaleFp32 + ); + } else { + lhsScaleOut = lhsScaleBF16; + } + } else { + lhsScaleOut = rewriter.create( + op.getLoc(), + RankedTensorType::get(lhsScaleTy.getShape(), fp32Ty), + lhsScale + ).getResult(); + } + + if (rhsScale && rhsScaleTy.getElementType().isIntOrIndex()) { + if (rhsScaleTy.getRank() != 2) { + return op.emitError("rhsScale must be 2D for transpose"); + } + + SmallVector transposedShape = { + rhsScaleTy.getShape()[1], + rhsScaleTy.getShape()[0] + }; + RankedTensorType transposedRhsScaleTy = RankedTensorType::get( + transposedShape, + rhsScaleTy.getElementType() + ); + + Value transposedRhsScale = rewriter.create( + op.getLoc(), + transposedRhsScaleTy, + rhsScale, + DenseI32ArrayAttr::get( + rewriter.getContext(), + ArrayRef{1, 0}) + ); + RankedTensorType rhsScaleI16Ty = RankedTensorType::get( + transposedShape, + i16Ty); + Value rhsScaleI16 = rewriter.create( + op.getLoc(), + rhsScaleI16Ty, + transposedRhsScale + ); + Value rhsShift127Empty = rewriter.create( + op.getLoc(), + rhsScaleI16Ty.getShape(), + i16Ty + ); + Value rhsShift127 = rewriter.create( + op.getLoc(), + ValueRange{c127}, + ValueRange{rhsShift127Empty} + ).getResult(0); + + Value rhsScaleI16Add127 = rewriter.create( + op.getLoc(), + rhsScaleI16, + rhsShift127 + ); + Value rhsShift7Empty = rewriter.create( + op.getLoc(), + rhsScaleI16Ty.getShape(), + i16Ty + ); + Value rhsShift7 = rewriter.create( + op.getLoc(), + ValueRange{c7}, + ValueRange{rhsShift7Empty} + ).getResult(0); + Value rhsScaleI16Shifted = rewriter.create( + op.getLoc(), + rhsScaleI16Add127, + rhsShift7 + ); + + RankedTensorType 
rhsScaleBF16Ty = RankedTensorType::get(transposedShape, bf16Ty); + Value rhsScaleBF16 = rewriter.create( + op.getLoc(), + rhsScaleBF16Ty, + rhsScaleI16Shifted + ); + + if (rhsTy.getElementType() == fp16Ty) { + RankedTensorType rhsScaleFp32Ty = RankedTensorType::get(transposedShape, fp32Ty); + Value rhsScaleFp32 = rewriter.create( + op.getLoc(), + rhsScaleFp32Ty, + rhsScaleBF16 + ); + RankedTensorType rhsScaleFp16Ty = RankedTensorType::get(transposedShape, fp16Ty); + rhsScaleOut = rewriter.create( + op.getLoc(), + rhsScaleFp16Ty, + rhsScaleFp32 + ); + } else { + rhsScaleOut = rhsScaleBF16; + } + int64_t rhsD0 = rhsScaleTy.getShape()[1]; + int64_t rhsD1 = rhsScaleTy.getShape()[0]; + SmallVector rhsExpandedShape1 = {rhsD0, rhsD1, 1}; + RankedTensorType rhsExpandedTy1 = RankedTensorType::get(rhsExpandedShape1, rhsTy.getElementType()); + Value rhsExpanded1 = rewriter.create( + op.getLoc(), + rhsExpandedTy1, + rhsScaleOut, + rewriter.getI32IntegerAttr(2) + ).getResult(); + + int64_t rhsDim1 = rhsTy.getShape()[0]; + if (rhsDim1 % rhsD0 != 0) { + return op.emitError("rhs dim0 must be an integer multiple of rhsScale dim0"); + } + int64_t rhsD2 = rhsDim1 / rhsD0; + SmallVector rhsBroadcastShape = {rhsD0, rhsD1, rhsD2}; + RankedTensorType rhsBroadcastTy = RankedTensorType::get(rhsBroadcastShape, rhsTy.getElementType()); + Value rhsBroadcasted = rewriter.create( + op.getLoc(), + rhsBroadcastTy, + rhsExpanded1 + ).getResult(); + + SmallVector transposeOrder = {0, 2, 1}; + Value transposedBroadcasted = rewriter.create( + op.getLoc(), + RankedTensorType::get({rhsD0, rhsD2, rhsD1}, rhsTy.getElementType()), + rhsBroadcasted, + DenseI32ArrayAttr::get(rewriter.getContext(), transposeOrder) + ); + SmallVector rhsReassociation; + rhsReassociation.push_back({0, 1}); + rhsReassociation.push_back({2}); + + Value scaledRhs = rewriter.create( + op.getLoc(), + RankedTensorType::get({rhsD0 * rhsD2, rhsD1}, rhsTy.getElementType()), + transposedBroadcasted, + rhsReassociation + ).getResult(); + + rhs = rewriter.create( + op.getLoc(), + rhs, + scaledRhs + ).getResult(); + } + + int64_t D0 = lhsScaleTy.getShape()[0]; + int64_t D1 = lhsScaleTy.getShape()[1]; + SmallVector expandedShape1 = {D0, D1, 1}; + RankedTensorType expandedTy1 = RankedTensorType::get(expandedShape1, lhsTy.getElementType()); + Value expanded1 = rewriter.create( + op.getLoc(), + expandedTy1, + lhsScaleOut, + rewriter.getI32IntegerAttr(2) + ).getResult(); + + int64_t lhsDim1 = lhsTy.getShape()[1]; + if (lhsDim1 % D1 != 0) { + return op.emitError("lhs dim1 must be an integer multiple of lhsScale dim1"); + } + int64_t D2 = lhsDim1 / D1; + SmallVector broadcastShape = {D0, D1, D2}; + RankedTensorType broadcastTy = RankedTensorType::get(broadcastShape, lhsTy.getElementType()); + Value broadcasted = rewriter.create( + op.getLoc(), + broadcastTy, + expanded1 + ).getResult(); + + SmallVector reassociation; + reassociation.push_back({0}); + reassociation.push_back({1, 2}); + + Value scaledLhs = rewriter.create( + op.getLoc(), + RankedTensorType::get({D0, D1 * D2}, lhsTy.getElementType()), + broadcasted, + reassociation + ).getResult(); + + Value scaledLhsFinal = rewriter.create( + op.getLoc(), + lhs, + scaledLhs + ).getResult(); + + Operation *matmulOp; + if (dstType.getRank() == 2) { + matmulOp = rewriter.create( + op.getLoc(), ValueRange{scaledLhsFinal, rhs}, ValueRange{c} + ); + } else if (dstType.getRank() == 3) { + matmulOp = rewriter.create( + op.getLoc(), ValueRange{scaledLhsFinal, rhs}, ValueRange{c} + ); + } else { + return op.emitError("DotScaledOp 
only support 2D or 3D tensor"); + } + + rewriter.replaceOp(op, matmulOp->getResults()); + return success(); +} + +LogicalResult +PtrToIntConverter::matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value ptr = adaptor.getSrc(); + + if (!mlir::isa(ptr.getType())) { + return rewriter.notifyMatchFailure(op, "input is not a memref type"); + } + + auto resultType = op.getType(); + + // memref.extract_aligned_pointer_as_index is used to obtain the integer representation of the base address. + auto ptrToIndexOp = rewriter.create( + loc, ptr); + + Value intResult = rewriter.create( + loc, resultType, ptrToIndexOp); + + rewriter.replaceOp(op, intResult); + return success(); +} + +} // namespace TTOpConverters diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp new file mode 100644 index 000000000..aa7046546 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp @@ -0,0 +1,911 @@ +#include "TritonToLinalg/TritonToLinalgPass.h" +#include "TritonToLinalg/ArgMinMaxConverter.h" +#include "TritonToLinalg/FunctionConverter.h" +#include "TritonToLinalg/LoadStoreConverter.h" +#include "TritonToLinalg/TritonOpConverter.h" +#include "TritonToLinalg/DescriptorConverter.h" +#include "TritonToLinalg/UseAnalysis.h" +#include "Utils/InterleaveOptimization.h" +#include "Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include +#include +#include + +#define DEBUG_TYPE "triton-to-linalg" + +using namespace mlir; +using namespace triton; + +int nd2nzFlag = 0; +bool existDotFlag = false; + +TritonTypeConverter::TritonTypeConverter() { + addConversion([](Type type) { return type; }); + + addConversion([](triton::PointerType ptrType) { + return MemRefType::get({ShapedType::kDynamic}, ptrType.getPointeeType()); + }); + + addConversion([](TensorType tensorType) -> Type { + auto elemType = tensorType.getElementType(); + if (auto ptrType = dyn_cast(elemType)) { + elemType = ptrType.getPointeeType(); + } + return MemRefType::get(tensorType.getShape(), elemType); + }); +} + +void 
TritonToLinalgPass::addProgramInfo(triton::FuncOp func,
+ bool globalKernel) {
+ OpBuilder b(func);
+
+ auto origFuncType = func.getFunctionType();
+ auto origInputTypes = origFuncType.getInputs();
+ SmallVector newInputTypes(origInputTypes);
+ newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type());
+
+ auto newFuncType =
+ b.getFunctionType(newInputTypes, origFuncType.getResults());
+
+ func.setFunctionType(newFuncType);
+
+ // If needed, append attributes for the newly added arguments
+ if (func.getAllArgAttrs()) {
+ SmallVector newArgAttrs;
+ func.getAllArgAttrs(newArgAttrs);
+ newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr());
+ func.setAllArgAttrs(newArgAttrs);
+ }
+
+ // Add the corresponding arguments to the function body
+ for (unsigned i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) {
+ func.getBody().front().addArgument(b.getI32Type(), func.getLoc());
+ }
+
+ if (globalKernel) {
+ func->setAttr(globalKernelAttr, b.getStringAttr(""));
+ } else {
+ func->setAttr(globalKernelAttr, b.getStringAttr("local"));
+ }
+}
+
+static void setBlockArgumentAttr(BlockArgument blockArg, triton::FuncOp func, TensorKind tensorKind)
+{
+ unsigned argIdx = blockArg.getArgNumber();
+ auto existingAttr = func.getArgAttrOfType(argIdx, "tt.tensor_kind");
+ TensorKind oldVal = existingAttr ? static_cast(existingAttr.getInt()) : TensorKind::NONE;
+
+ TensorKind finalVal = tensorKind;
+ if ((oldVal == TensorKind::INPUT && tensorKind == TensorKind::OUTPUT) ||
+ (oldVal == TensorKind::OUTPUT && tensorKind == TensorKind::INPUT)) {
+ finalVal = TensorKind::INPUT_OUTPUT;
+ } else if (oldVal == TensorKind::INPUT_OUTPUT) {
+ finalVal = oldVal;
+ }
+
+ func.setArgAttr(argIdx, "tt.tensor_kind",
+ IntegerAttr::get(IntegerType::get(func.getContext(), INT_BIT_WIDTH), static_cast(finalVal)));
+}
+
+template
+void TritonToLinalgPass::addTensorKindToArguments(OpTy op, triton::FuncOp func, TensorKind tensorKind)
+{
+ Value ptr = op.getPtr();
+ if (!ptr)
+ return;
+
+ Value cur = ptr;
+ llvm::SmallPtrSet visited;
+ // Walk back along the def-use chain to find the originating BlockArgument
+ while (visited.insert(cur).second) {
+ // If it is a BlockArgument, try to set the attribute
+ if (auto blockArg = dyn_cast(cur)) {
+ if (blockArg.getOwner() == &func.getBody().front()) {
+ auto type = blockArg.getType();
+ // Check whether it is a triton::PointerType
+ if (!isa(type))
+ break;
+ setBlockArgumentAttr(blockArg, func, tensorKind);
+ break;
+ }
+ }
+
+ Operation *defOp = cur.getDefiningOp();
+ if (!defOp)
+ break;
+ cur = defOp->getOperand(0);
+ }
+}
+
+LogicalResult
+TritonToLinalgPass::convertMultipleBlockControlFlow(Operation *funcOp,
+ OpBuilder &builder) {
+ if (!isa(funcOp)) {
+ funcOp->emitError("convertMultipleBlockControlFlow can only process func::FuncOp!");
+ return failure();
+ }
+
+ SmallVector candidate;
+ SmallVector eraseBlocks;
+ for (Block &block : dyn_cast(funcOp).getBody()) {
+ auto curTerminator = block.getTerminator();
+ if (isa(curTerminator))
+ candidate.push_back(curTerminator);
+ else if (isa(curTerminator)) {
+ if (candidate.empty()) {
+ curTerminator->emitError("funcOp has more than one Block but got an early 'tt.return' Op.");
+ return failure();
+ }
+ }
+ else
+ return failure();
+
+ if (!block.isEntryBlock())
+ eraseBlocks.push_back(&block);
+ }
+
+ if (candidate.empty()) {
+ funcOp->emitError("funcOp has more than one Block but no candidate Terminator was found!");
+ return failure();
+ }
+
+ llvm::BitVector visitFlag(candidate.size(), false);
+
+ // Recursive function to convert all cf::CondBranchOp to scf::IfOp
+ std::function convertToSCF =
+ [&](Operation *op, Operation *insertPosOp) -> void {
+ auto condBranchOp = 
dyn_cast_if_present(op);
+ auto iter = llvm::find(candidate, condBranchOp);
+ if (!(condBranchOp && iter != candidate.end())) {
+ op->emitError("convertToSCF must be called with a CondBranchOp from candidates!");
+ return;
+ }
+ visitFlag.set(iter - candidate.begin());
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfter(insertPosOp);
+
+ // Note: this deliberately destroys the original control flow
+ builder.create(
+ condBranchOp->getLoc(), condBranchOp.getCondition(),
+ /*thenBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ SmallVector movedOps = llvm::map_to_vector(
+ condBranchOp.getTrueDest()->without_terminator(),
+ [](Operation &op) { return &op; });
+ for (auto *innerOp : movedOps) {
+ innerOp->moveBefore(builder.getInsertionBlock(),
+ builder.getInsertionPoint());
+ }
+
+ auto blockTerm = condBranchOp.getTrueDest()->getTerminator();
+ if (isa(blockTerm)) {
+ if (movedOps.empty()) {
+ blockTerm->emitError("movedOps cannot be empty before entering convertToSCF!");
+ return;
+ }
+ convertToSCF(blockTerm, movedOps.back());
+ }
+
+ builder.create(loc);
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ SmallVector movedOps = llvm::map_to_vector(
+ condBranchOp.getFalseDest()->without_terminator(),
+ [](Operation &op) { return &op; });
+ for (auto *innerOp : movedOps) {
+ innerOp->moveBefore(builder.getInsertionBlock(),
+ builder.getInsertionPoint());
+ }
+
+ auto blockTerm = condBranchOp.getFalseDest()->getTerminator();
+ if (isa(blockTerm)) {
+ if (movedOps.empty()) {
+ blockTerm->emitError("movedOps cannot be empty before entering convertToSCF!");
+ return;
+ }
+ convertToSCF(blockTerm, movedOps.back());
+ }
+
+ builder.create(loc);
+ });
+ };
+
+ Block::iterator insertOp(candidate.front());
+ --insertOp;
+ convertToSCF(candidate.front(), &(*insertOp));
+
+ if (!visitFlag.all())
+ return failure();
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(candidate.front());
+ builder.create(candidate.front()->getLoc());
+
+ for (Operation *eachTerm : candidate)
+ eachTerm->erase();
+ for (Block *block : llvm::reverse(eraseBlocks))
+ block->erase();
+
+ return success();
+}
+
+void TritonToLinalgPass::convertTTFunc(triton::FuncOp func,
+ const bool existDot) {
+ OpBuilder builder(func);
+
+ auto name = func.getName();
+ auto type = func.getFunctionType();
+
+ SmallVector argAttrs, resAttrs;
+ func.getAllArgAttrs(argAttrs);
+ func.getAllResultAttrs(resAttrs);
+
+ // Special handling for bit-casted tt.ptr arguments
+ SmallVector inputTypes{type.getInputs()};
+ SmallVector retTypes{type.getResults()};
+ if (func.getSymVisibility() == "public" && !func.isDeclaration()) {
+ for (size_t i = 0; i < func.getNumArguments(); ++i) {
+ auto arg = func.getArgument(i);
+ // Special handling for i1 args
+ if (!isa(arg.getType()) ||
+ dyn_cast(arg.getType()).getElementTypeBitWidth() !=
+ 1) {
+ continue;
+ }
+
+ SmallVector argValidUser{arg.getUsers()};
+ llvm::erase_if(argValidUser, [](Operation *op) -> bool {
+ return isOpTriviallyDead(op);
+ });
+
+ if (!argValidUser.empty()) {
+ LLVM_DEBUG({
+ auto &os = llvm::dbgs();
+ os << arg << " has users:\n";
+ int cnt = 0;
+ for (auto it : argValidUser) {
+ os << "users[" << cnt++ << "] = " << *it;
+ }
+ });
+ if (llvm::all_of(argValidUser, [](Operation *userOp) {
+ return isa(userOp);
+ })) {
+ auto castOp = cast(*argValidUser.begin());
+ if (castOp.getInputs().size() == 1 &&
+ castOp.getOutputs().size() == 1) {
+ arg.setType(castOp.getOutputs()[0].getType());
+ inputTypes[i] = arg.getType();
+ }
+ } else {
+ 
func->emitError(Twine("Unsupported use of func arg at index ") +
+ Twine(i));
+ }
+ } else {
+ // Handle unused bool-pointer arguments specially so that the argument
+ // type reflects the realistic memory layout (a bool pointer element is
+ // 8 bits wide) and does not mislead the backend compiler.
+ auto memType = dyn_cast(arg.getType())
+ .cloneWith(std::nullopt, builder.getI8Type());
+ arg.setType(memType);
+ inputTypes[i] = arg.getType();
+ }
+ }
+ }
+ auto castType = FunctionType::get(func.getContext(), inputTypes, retTypes);
+
+ auto funcFunc = builder.create(func.getLoc(), name, castType);
+ funcFunc.setAllArgAttrs(argAttrs);
+ funcFunc.setAllResultAttrs(resAttrs);
+ auto kernelAttr = func->getAttr(globalKernelAttr);
+ if (kernelAttr) {
+ funcFunc->setAttr(globalKernelAttr, kernelAttr);
+ }
+ std::string kernelMixMode = "aiv";
+ if (existDot) {
+ // "mix" also works for a pure cube kernel because it uses the same MAGIC_ELF keyword
+ kernelMixMode = "mix";
+ }
+ // Set mix_mode in the func attrs so that the backend can determine
+ // the mix_mode by parsing the func attrs.
+ // The backend needs to know the mix_mode because the host wrapper
+ // needs to set the devbin.magic. Check npu_utils.cpp.
+ funcFunc->setAttr(kernelMixModeName, builder.getStringAttr(kernelMixMode));
+
+ auto &funcFuncBody = funcFunc.getBody();
+ auto &funcBody = func.getBody();
+
+ IRMapping map;
+ funcBody.cloneInto(&funcFuncBody, map);
+
+ if (!funcFuncBody.hasOneBlock()) {
+ if (failed(convertMultipleBlockControlFlow(funcFunc, builder))) {
+ llvm_unreachable("Encountered unsupported control flow");
+ }
+ }
+
+ for (Block &block : funcFuncBody.getBlocks()) {
+ auto term = block.getTerminator();
+ builder.setInsertionPoint(term);
+ builder.create(func.getLoc(), term->getOperands());
+ term->erase();
+ }
+ func.erase();
+}
+
+
+void TritonToLinalgPass::addDynamicLegal(
+ ConversionTarget &target, TritonTypeConverter &tritonTypeConverter) {
+ target.addLegalDialect<
+ func::FuncDialect, arith::ArithDialect, math::MathDialect,
+ linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect,
+ cf::ControlFlowDialect, tensor::TensorDialect, LLVM::LLVMDialect,
+ bufferization::BufferizationDialect, memref::MemRefDialect,
+ annotation::AnnotationDialect, hivm::HIVMDialect>();
+
+ // Add legal ops on condition
+ target.addLegalOp();
+
+ // Decide which ops need to be converted based on the conditions below
+ target.addDynamicallyLegalOp(
+ [](mlir::Operation *op) {
+ return !op->use_empty();
+ });
+
+ target.addDynamicallyLegalOp([&](triton::FuncOp op) {
+ return tritonTypeConverter.isSignatureLegal(op.getFunctionType());
+ });
+
+ target.addDynamicallyLegalOp([](arith::ConstantOp op) {
+ auto res = op.getResult();
+ if (!isa(res.getType())) {
+ return true;
+ }
+
+ if (auto denseAttr = dyn_cast(op.getValue())) {
+ if (!denseAttr.isSplat() ||
+ !isa(denseAttr.getElementType())) {
+ return true;
+ }
+ if (res.hasOneUse() && isa(*res.user_begin())) {
+ return true;
+ }
+ return false;
+ }
+ return true;
+ });
+
+ target.addDynamicallyLegalOp([](Operation *op) {
+ return llvm::all_of(op->getOperandTypes(), [](Type t) {
+ if (isa(t)) {
+ return false;
+ }
+ if (auto shapedType = dyn_cast(t)) {
+ return shapedType.getElementType().isIntOrFloat();
+ }
+ assert(t.isIntOrIndexOrFloat());
+ return true;
+ });
+ });
+
+ target.addDynamicallyLegalDialect(
+ [this](Operation *op) {
+ if (op->hasAttr("MetaUse")) {
+ return false;
+ }
+
+ if (isa(op)) {
+ return true;
+ }
+
+ bool operateOnTensors =
+ llvm::all_of(op->getOperandTypes(),
+ [](Type type) { 
return isa(type); }); + + return this->namedOps || !operateOnTensors; + }); +} + +void TritonToLinalgPass::populateTritonToLinalgCanonicalizationPatterns(RewritePatternSet &patterns) +{ + patterns.add, + LoadStoreConverter::LoadStoreCanonicalizer, + LoadStoreConverter::LoadStoreCanonicalizer, + LoadStoreConverter::LoadStoreCanonicalizer>(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add< + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + // TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer, + TTOpConverters::ScalarMathCanonicalizer + // By test, the following ops do not need canonicalization. 
+ // TTOpConverters::ScalarMathCanonicalizer
+ // TTOpConverters::ScalarMathCanonicalizer
+ // TTOpConverters::ScalarMathCanonicalizer
+ >(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+}
+
+void TritonToLinalgPass::populateTritonToLinalgConversionPatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ unsigned int launchGridRank) {
+ nd2nzFlag = this->enableNd2nzOnVector;
+ populateFunctionOpInterfaceTypeConversionPattern(
+ patterns, typeConverter);
+
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(
+ patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ // reduce converters
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+
+ patterns.add(patterns.getContext());
+ patterns.add(
+ patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(
+ patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add>(patterns.getContext());
+ patterns.add>(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+
+ if (!this->namedOps) {
+ linalg::populateElementwiseToLinalgConversionPatterns(patterns);
+ }
+}
+
+void TritonToLinalgPass::getDependentDialects(DialectRegistry &registry) const {
+ registry.insert();
+}
+
+LogicalResult TritonToLinalgPass::processDescriptorOperations(ModuleOp moduleOp)
+{
+ // --- ConversionTarget: dynamic legality checks ---
+ mlir::ConversionTarget target(getContext());
+
+ // Dialect-level dynamic legality: these dialects are legal as long as an
+ // op's operands/results contain no TensorDescType
+ target.addDynamicallyLegalDialect(
+ [](mlir::Operation *op) {
+ return !DescriptorConverter::hasATensorDescriptorType(op->getOperandTypes()) &&
+ !DescriptorConverter::hasATensorDescriptorType(op->getResultTypes());
+ });
+ // Function-signature legality: a triton::FuncOp is legal if its
+ // inputs/outputs contain no TensorDescType
+ target.addDynamicallyLegalOp([](triton::FuncOp funcOp) {
+ return !DescriptorConverter::hasATensorDescriptorType(funcOp.getFunctionType().getInputs()) &&
+ !DescriptorConverter::hasATensorDescriptorType(funcOp.getFunctionType().getResults());
+ });
+ target.addLegalOp();
+ target.addIllegalOp();
+
+ // --- Patterns ---
+ mlir::RewritePatternSet patterns(&getContext());
+ patterns.add(patterns.getContext());
+ patterns.add(patterns.getContext());
+
+ mlir::ConversionConfig config;
+ config.buildMaterializations = true;
+ if (failed(applyPartialConversion(moduleOp, target, std::move(patterns), config))) {
+ moduleOp->emitError("failed to convert tensor 
descriptor operations");
+ return failure();
+ }
+
+ return success();
+}
+
+void TritonToLinalgPass::runOnOperation() {
+ auto moduleOp = getOperation();
+
+ // Check if the kernel contains tl.dot. Without tl.dot,
+ // the kernel would be a pure AIV kernel.
+ bool existDot = false;
+ moduleOp.walk([&](triton::DotOp dotOp) {
+ existDot = true;
+ return WalkResult::interrupt();
+ });
+ moduleOp.walk([&](triton::DotScaledOp dotScaledOp) {
+ existDot = true;
+ return WalkResult::interrupt();
+ });
+ existDotFlag = existDot;
+
+ RewritePatternSet canonicalizerPatterns(&getContext());
+
+ // Execute tensor descriptor operations conversion
+ if (failed(processDescriptorOperations(moduleOp))) {
+ signalPassFailure();
+ }
+
+ // Traverse all the triton::FuncOps to add the tensor_kind attribute
+ moduleOp.walk([&](triton::FuncOp func) {
+ func.walk([&](triton::LoadOp loadOp) {
+ addTensorKindToArguments(loadOp, func, TensorKind::INPUT);
+ });
+ func.walk([&](triton::StoreOp storeOp) {
+ addTensorKindToArguments(storeOp, func, TensorKind::OUTPUT);
+ });
+ func.walk([&](triton::AtomicRMWOp atomicOp) {
+ addTensorKindToArguments(atomicOp, func, TensorKind::INPUT_OUTPUT);
+ });
+ func.walk([&](triton::AtomicCASOp atomicOp) {
+ addTensorKindToArguments(atomicOp, func, TensorKind::INPUT_OUTPUT);
+ });
+ });
+
+ // 1. Canonicalize load/store ops (LoadStore/ScalarStore canonicalizers)
+ this->populateTritonToLinalgCanonicalizationPatterns(canonicalizerPatterns);
+ if (failed(applyPatternsAndFoldGreedily(moduleOp,
+ std::move(canonicalizerPatterns)))) {
+ moduleOp->emitError("failed to apply Canonicalizer Patterns");
+ signalPassFailure();
+ }
+
+ // 2. Run use analysis
+ moduleOp.walk([this](triton::FuncOp op) {
+ if (failed(runUseAnalysis(op))) {
+ signalPassFailure();
+ }
+ });
+
+ RewritePatternSet patterns(&getContext());
+ ConversionTarget target(getContext());
+ TritonTypeConverter tritonTypeConverter{};
+
+ // 3. Mark legal dialects and ops
+ this->addDynamicLegal(target, tritonTypeConverter);
+
+ // 4. Mark ops that must be converted, including tt.scan
+ auto loopOpLegalFn = [](LoopLikeOpInterface op) {
+ return !op.getOperation()->hasAttr("UnhandledLoopOp");
+ };
+
+ target.addIllegalOp();
+ target.addDynamicallyLegalOp(loopOpLegalFn);
+ target.addDynamicallyLegalOp(loopOpLegalFn);
+
+ // 5. Register converters for the illegal ops
+ this->populateTritonToLinalgConversionPatterns(tritonTypeConverter, patterns,
+ LAUNCH_GRID_RANK);
+
+ // 6. Walk the kernel functions and add the program-id / number-of-programs
+ // arguments
+ for (auto func : getOperation().getOps()) {
+ addProgramInfo(func, globalKernel);
+ }
+
+ moduleOp.walk([this](LoopLikeOpInterface loopOp) {
+ auto *op = loopOp.getOperation();
+ if (!op->hasAttr("ExtractedLoadOrStore"))
+ op->setAttr("UnhandledLoopOp", UnitAttr::get(op->getContext()));
+
+ for (auto res: loopOp->getResults()) {
+ if (auto tensorType = dyn_cast(res.getType());
+ tensorType && !isa(tensorType.getElementType())) {
+ IRRewriter rewriter(op->getContext());
+ rewriter.setInsertionPointAfter(op);
+ auto newVal = rewriter.create(op->getLoc(), res.getType(), res);
+ rewriter.replaceAllUsesExcept(res, newVal, newVal);
+ }
+ }
+ });
+
+ // 7. Apply the op conversion
+ if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+ moduleOp->emitError("failed to apply Conversion Patterns");
+ signalPassFailure();
+ }
+
+ // 8. Convert function headers and terminators (tt.func -> func.func)
+ moduleOp.walk(
+ [&](triton::FuncOp func) { this->convertTTFunc(func, existDot); });
+
+ // 9. Remove dead code and simplify the IR.
+ PassManager pm(&getContext(), moduleOp.getOperationName());
+ pm.addPass(createCSEPass());
+ pm.addPass(createCanonicalizerPass());
+ if (failed(runPipeline(pm, getOperation()))) {
+ 
signalPassFailure(); + } + + // Calculate size of PointerCastOp precisely + SmallVector castOps; + + moduleOp.walk([&](hivm::PointerCastOp op) { castOps.push_back(op); }); + + for (auto op : castOps) { + SmallVector userOps(op->getUsers().begin(), + op->getUsers().end()); + IRRewriter rewriter(&getContext()); + rewriter.setInsertionPointAfter(op); + Value addr = op.getAddrs()[0]; + auto elementType = + cast(op.getResult().getType()).getElementType(); + Value elementTypeSize; + if (auto intType = dyn_cast(elementType)) { + elementTypeSize = rewriter.create(op.getLoc(), rewriter.getIntegerAttr(addr.getType(), intType.getWidth() / 8)); + } else if (auto floatType = dyn_cast(elementType)) { + elementTypeSize = rewriter.create(op.getLoc(), rewriter.getIntegerAttr(addr.getType(), floatType.getWidth() / 8)); + } else { + llvm_unreachable("Cannot get memory size"); + } + + for (auto userOp : userOps) { + auto reinterpretCastOp = cast(userOp); + auto sizes = reinterpretCastOp.getStaticSizes(); + auto staticStrides = reinterpretCastOp.getStaticStrides(); + auto strides = reinterpretCastOp.getStrides(); + if(reinterpretCastOp.getStaticOffsets().size() != 1) + userOp->emitError("IntToPtrOp must converted to PointerCastOp of memref type"); + int64_t castOpSize = 0; + SmallVector dynamicSizes; + for (const auto &[size, stride] : llvm::zip_equal(sizes, staticStrides)) { + assert(!ShapedType::isDynamic(size)); + if (ShapedType::isDynamic(stride)) + dynamicSizes.push_back(size); + else + castOpSize = size * stride; + } + rewriter.setInsertionPoint(reinterpretCastOp); + Value dynamicSize = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(castOpSize)); + for (const auto &[size, stride] : + llvm::zip_equal(dynamicSizes, strides)) { + Value axisSize = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(size)); + axisSize = + rewriter.create(op.getLoc(), stride, axisSize); + dynamicSize = rewriter.create(op.getLoc(), dynamicSize, + axisSize); + } + Value offsetValue; + auto staticOffset = reinterpretCastOp.getStaticOffsets()[0]; + if (ShapedType::isDynamic(staticOffset)) { + offsetValue = reinterpretCastOp.getOffsets()[0]; + if (offsetValue.getType() != addr.getType()) + offsetValue = rewriter.create( + op.getLoc(), addr.getType(), offsetValue); + } else { + offsetValue = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(addr.getType(), staticOffset)); + } + offsetValue = rewriter.create(op.getLoc(), offsetValue, elementTypeSize); + Value realAddr = rewriter.create(op.getLoc(), addr, offsetValue); + auto memrefType = MemRefType::get({ShapedType::kDynamic}, elementType); + auto newCastOp = rewriter.create( + op.getLoc(), memrefType, realAddr, dynamicSize); + auto markOp = rewriter.create(op.getLoc(), + newCastOp.getResult()); + markOp->setAttr(hivm::AddressSpaceAttr::getMnemonic(), + {hivm::AddressSpaceAttr::get(rewriter.getContext(), + hivm::AddressSpace::GM)}); + rewriter.replaceOpWithNewOp( + reinterpretCastOp, + cast(reinterpretCastOp.getResult().getType()), newCastOp, + ValueRange({}), reinterpretCastOp.getSizes(), + reinterpretCastOp.getStrides(), SmallVector({0}), + reinterpretCastOp.getStaticSizes(), + reinterpretCastOp.getStaticStrides()); + } + rewriter.eraseOp(op); + } + + // Try interleave optimization + llvm::DenseMap> interleaveCandidate; + llvm::DenseMap> + interleaveCandidateWithMask; + moduleOp.walk([&](bufferization::MaterializeInDestinationOp materializeOp) { + if (auto reinterpretCastOp = + materializeOp.getDest() + .getDefiningOp()) { + if 
(llvm::isa(reinterpretCastOp.getSource()) &&
+ reinterpretCastOp.getStaticStrides().back() == 2) {
+ interleaveCandidate[llvm::cast(
+ reinterpretCastOp.getSource())]
+ .push_back(materializeOp);
+ }
+ }
+
+ // The difference is that the converted op chain of a masked store has a
+ // `memref::SubViewOp`
+ if (auto subviewOp =
+ materializeOp.getDest().getDefiningOp()) {
+ if (!llvm::isa(
+ materializeOp.getSource().getDefiningOp()))
+ return WalkResult::advance();
+
+ if (auto reinterpretCastOp =
+ subviewOp.getSource()
+ .getDefiningOp()) {
+ if (llvm::isa(reinterpretCastOp.getSource()) &&
+ reinterpretCastOp.getStaticStrides().back() == 2) {
+ interleaveCandidateWithMask[llvm::cast(
+ reinterpretCastOp.getSource())]
+ .push_back(materializeOp);
+ }
+ }
+ }
+
+ return WalkResult::advance();
+ });
+
+ for (auto [blockArg, materializeVec] : interleaveCandidate) {
+ // Only enable the optimization when there are exactly two materializeOps
+ // with the same block-argument destination.
+ if (materializeVec.size() != 2)
+ continue;
+ auto result = InterleaveStatusOptimization(materializeVec);
+ }
+
+ for (auto [blockArg, materializeVec] : interleaveCandidateWithMask) {
+ if (materializeVec.size() != 2)
+ continue;
+ auto result = InterleaveStatusWithMaskOptimization(materializeVec);
+ }
+
+ // Forcibly add stub arguments (sync-block lock and workspace) at the
+ // beginning of the function arguments. Default type is memref
+ for (auto func : getOperation().getOps()) {
+ if (!func->hasAttr("global_kernel"))
+ continue;
+
+ auto context = func.getContext();
+ constexpr int64_t syncBlockLockArgIdx = 0;
+ NamedAttribute syncBlockLockArgAttr(StringAttr::get(context, "syncBlockLock"),
+ UnitAttr::get(context));
+ MemRefType syncBlockLockArgType =
+ MemRefType::get(SmallVector(1, ShapedType::kDynamic),
+ IntegerType::get(context, 8));
+ func.insertArgument(syncBlockLockArgIdx, // argIndex
+ syncBlockLockArgType, // argType
+ nullptr, func->getLoc()); // dicAttr
+ func->setAttr("SyncBlockLockArgIdx",
+ IntegerAttr::get(IntegerType::get(&getContext(), 64), 0)); // 64: 64-bit integer type
+
+ constexpr int64_t workspaceArgIdx = 1;
+ MemRefType workspaceArgType =
+ MemRefType::get(SmallVector(1, ShapedType::kDynamic),
+ IntegerType::get(context, 8));
+ NamedAttribute workspaceArgAttr(StringAttr::get(context, "workspace"),
+ UnitAttr::get(context));
+
+ func.insertArgument(/*argIndex*/ workspaceArgIdx,
+ /*argType*/ workspaceArgType,
+ /*dicAttr*/ nullptr, func->getLoc());
+ func->setAttr("WorkspaceArgIdx",
+ IntegerAttr::get(IntegerType::get(&getContext(), 64), 1)); // 64: 64-bit integer type
+ }
+
+ // Fix the Location info
+ moduleOp.walk([&](Operation *op) {
+ auto loc = op->getLoc();
+ if (isa(loc)) {
+ llvm::SmallPtrSet stopOps;
+ traverseForwardUpdateUserChainIf(
+ op,
+ /*conditionFn*/
+ [](Operation *curOp) { return false; },
+ /*stopFn*/
+ [](Operation *curOp) { return !isa(curOp->getLoc()); },
+ /*actionFn*/
+ nullptr, stopOps);
+ if (stopOps.empty()) {
+ op->emitWarning() << *op << " and its users all have no location!";
+ } else {
+ Operation *goodOp = *stopOps.begin();
+ op->setLoc(goodOp->getLoc());
+ }
+ }
+ return WalkResult::advance();
+ });
+}
+
+std::unique_ptr> triton::createTritonToLinalgPass() {
+ return std::make_unique();
+}
diff --git a/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp
new file mode 100644
index 000000000..a083a50eb
--- /dev/null
+++ b/third_party/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp
@@ -0,0 +1,460 @@
+//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "TritonToLinalg/UseAnalysis.h" +#include "Utils/Utils.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +using namespace triton; +using namespace dataflow; + +#define DEBUG_TYPE "triton-use-analysis" + +std::string stringifyUseType(UseType useTy) { + std::string ret; + if (useTy == UseType::MetaUse) { + ret = "MetaUse"; + } else if (useTy == UseType::DataUse) { + ret = "DataUse"; + } else if (useTy == UseType::MixUse) { + ret = "MixUse"; + } else if (useTy == UseType::Undefined) { + ret = "Undefined"; + } + return ret; +} + +#if LLVM_VERSION_MAJOR >= 20 +LogicalResult +triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) { +#else +void triton::UseAnalysis::visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) { +#endif + + if (op->getResults().size() == 1) { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (resultType && isa(resultType.getElementType())) { + for (auto opnd : operands) { + propagateUse(opnd, UseType::MetaUse); + } + } + } + + TypeSwitch(op) + .Case([&](auto load) { + propagateUse(operands[0], UseType::MetaUse); + auto mask = load.getMask(); + auto other = load.getOther(); + if (mask) { + assert(mask != other && "mask and other cannot be the same"); + propagateUse(operands[1], UseType::MetaUse); + } + if (other) { + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case( + [&](auto assert) { propagateUse(operands[0], UseType::DataUse); }) + .Case([&](auto store) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = store.getValue(); + auto mask = store.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + // Consider triton::AtomicRMWOp as store operation + .Case([&](auto atomicOp) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = atomicOp.getVal(); + auto mask = atomicOp.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto atomicOp) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + propagateUse(operands[2], UseType::DataUse); + auto value = atomicOp.getVal(); + }) + .Case([&](auto dot) { + propagateResults(operands[0], results); + propagateResults(operands[1], results); + + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) { + splat = opc.template getDefiningOp(); + } + + if (opc && splat && splat.getSrc().getDefiningOp()) { + propagateUse(operands[2], UseType::MetaUse); + } else { + propagateUse(operands[2], UseType::DataUse); + } + }) + .Case([&](auto loopOp) { + for (const auto &[yield, init, result]: llvm::zip_equal(loopOp.getYieldedValues(), loopOp.getInits(), results)) { + propagateResults(getLatticeElement(yield), {result}); + propagateResults(getLatticeElement(init), {result}); + } + }) + .Default([&](Operation *op) { + // this condition account for 
tt.addptr + for (auto operand : operands) { + propagateResults(operand, results); + } + }); +#if LLVM_VERSION_MAJOR >= 20 + return success(); +#endif +} + +void setMixUseRecursively(Operation *rootOp, bool applyRoot = true) { + traverseBackwardUpdateOperandChainIf( + rootOp, + // ConditionFn + [rootOp, applyRoot](Operation *curOp) { + for (auto res : curOp->getResults()) { + auto tensorType = dyn_cast(res.getType()); + if (tensorType && isa(tensorType.getElementType())) + return false; + } + return isMetaUse(curOp) && (curOp != rootOp || applyRoot); + }, + // StopFn + [rootOp](Operation *curOp) { + return isa(curOp) && curOp != rootOp; + }, + // ActionFn + [](OpBuilder &b, Operation *op) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(b.getContext())); }); + op->removeAttr("MetaUse"); + }); +} + +void postProcessLoopOp(LoopLikeOpInterface loopOp, const DataFlowSolver &solver) { + for (const auto &[res, yield, regionArg] : + llvm::zip_equal(loopOp->getResults(), loopOp.getYieldedValues(), + loopOp.getRegionIterArgs())) { + auto *defOp = yield.getDefiningOp(); + bool isMixUse = false; + if (!defOp) + continue; + std::function(Value, Value)> isIterArgMixUse = + [&](Value v, Value target) -> std::optional { + auto defOp = v.getDefiningOp(); + auto *use = solver.lookupState(v); + if (use && use->type == UseType::DataUse) + return true; + if (v == target) + return false; + if (!defOp) + return std::nullopt; + if (auto loopOp = dyn_cast(defOp)) { + auto resNum = cast(v).getResultNumber(); + auto res = isIterArgMixUse(loopOp.getInits()[resNum], target); + if (res.has_value()) { + bool isMixUse = res.value(); + Value yieldedValue = loopOp.getYieldedValues()[resNum]; + if (auto yieldDefOp = yieldedValue.getDefiningOp()) + isMixUse = isMixUse || !isMetaUse(yieldDefOp); + return isMixUse; + } + return std::nullopt; + } + for (auto oper : defOp->getOperands()) { + auto res = isIterArgMixUse(oper, target); + if (res.has_value()) + return res.value() || !isMetaUse(defOp); + } + return std::nullopt; + }; + if (solver.lookupState(res)->type == UseType::DataUse || + isIterArgMixUse(yield, regionArg).value_or(false)) + setMixUseRecursively(defOp); + } +} + +LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { + MLIRContext *context = funcOp.getContext(); + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(funcOp))) { + return failure(); + } + auto &os = llvm::dbgs(); + // Walk the func op, convert tags on operands to tags on operations + funcOp.walk([&](Operation *op) { + LLVM_DEBUG({ os << "[UseAnalysis] op is " << *op << "\n"; }); + UseType useType = UseType::Undefined; + for (auto result : op->getResults()) { + LLVM_DEBUG({ os << "[UseAnalysis] ===> result is " << result << "\n"; }); + auto use = solver.lookupState(result); + assert(use && "Lattice value not found"); + auto thisUseType = use->type; + LLVM_DEBUG({ + os << "[UseAnalysis] ==========> useType is " + << stringifyUseType(thisUseType) << "\n"; + }); + if (thisUseType == UseType::Undefined) { + continue; + } + if (useType == UseType::Undefined) { + useType = thisUseType; + } + if (thisUseType == UseType::MixUse || thisUseType != useType) { + useType = UseType::MixUse; + break; + } + } + + if (useType == UseType::Undefined) { + LLVM_DEBUG({ op->setAttr("Undefined", UnitAttr::get(context)); }); + return; + } else if (useType == UseType::MetaUse) { + if (!isa(op)) { + assert(op->getNumResults() == 1 && + "Ops used for meta 
computation are expected to have one result"); + } + for (auto it = 0; it < op->getNumResults(); ++it) { + // Only set the tag if the operation uses tensors + if (isa(op->getResult(it).getType()) || + (isa(op) && + op->hasAttr(ConverterUtils::discreteAttrName)) || + (isa(op) && + isa(op->getResult(it).getType()))) { + // Setting tag for erasing op later + op->setAttr("MetaUse", UnitAttr::get(context)); + } + } + return; + } else if (useType == UseType::DataUse) { + LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); }); + return; + } + + assert(useType == UseType::MixUse); + + // If the operation only produces scalars, no need to clone it + bool shapedResult = true; + for (auto result : op->getResults()) + shapedResult &= isa(result.getType()); + if (!shapedResult || isa(op)) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + llvm::SetVector metaUsers; + for (auto result : op->getResults()) { + for (auto user : result.getUsers()) { + TypeSwitch(user) + .Case([&](auto load) { + auto ptr = load.getPtr(); + auto mask = load.getMask(); + auto other = load.getOther(); + if (result == ptr || result == mask || result == other) { + metaUsers.insert(user); + } + }) + .Case([&](auto store) { + auto ptr = store.getPtr(); + auto mask = store.getMask(); + if (result == ptr || result == mask) { + metaUsers.insert(user); + } + }) + .Case([&](auto atomicOp) { + auto ptr = atomicOp.getPtr(); + auto mask = atomicOp.getMask(); + if (result == ptr || result == mask) + metaUsers.insert(user); + }) + .Case([&](auto atomicOp) { + auto ptr = atomicOp.getPtr(); + if (result == ptr) + metaUsers.insert(user); + }) + .Case([&](auto dot) { + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) { + splat = opc.template getDefiningOp(); + } + + if (opc && splat && + splat.getSrc().getDefiningOp()) { + metaUsers.insert(user); + } + }) + .Default([&](Operation *op) { + bool allMeta = true; + for (auto res : op->getResults()) { + auto resUse = solver.lookupState(res); + if (resUse->type != UseType::MetaUse) { + allMeta = false; + break; + } + } + if (allMeta) { + metaUsers.insert(user); + } + }); + } + } + + // If the operation doesn't have direct meta users, no need to clone it + if (metaUsers.empty()) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + return; + } + + // Clone the operation; switch all meta users to use the clone + OpBuilder builder(op); + auto clone = builder.clone(*op); + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + + // Setting tag for erasing op later + clone->setAttr("MetaUse", UnitAttr::get(context)); + + for (auto [res_i, result] : llvm::enumerate(op->getResults())) { + for (auto user : metaUsers) { + for (auto &operand : user->getOpOperands()) { + if (operand.get() == result) { + operand.set(clone->getResult(res_i)); + } + } + } + } + }); + LLVM_DEBUG({ + os << "[UseAnalysis] Before post-process, funcOp is " << *funcOp << "\n"; + }); + // Post-process + funcOp.walk([&](Operation *op) { + // Handle indirect load case. + // For example, load(1st) -> computeOp -> load(2nd). + // The first load is IndirectLoadInterfaceOp. + // Do not inplace replace MetaUse by MixUse. Because the condition checking + // depends on that the op has the attr of MetaUse. + // Handle the indirect load interface op + // We first trace from the 1st load to the 2nd load with the ops between + // them marked as MixUse. Then we traceback from the 2nd load to mark defs + // MixUse. 
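+ // A schematic example (pseudo-IR, names made up for illustration only):
+ //   %idx  = tt.load %idx_ptrs        // 1st load: provides offsets (indirect load interface op)
+ //   %off  = arith.extsi %idx         // intermediate compute between the loads
+ //   %ptrs = tt.addptr %base, %off
+ //   %val  = tt.load %ptrs            // 2nd load: the stop op of the forward trace
+ // The ops between the two loads must end up tagged MixUse so that they are
+ // converted rather than erased as pure meta ops.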
+ if (opIsIndirectLoad(op) || opIsIndirectCalc(op)) { + LLVM_DEBUG({ + os << "[UseAnalysis] Found indirect load interface op: " << *op << "\n"; + }); + llvm::SmallPtrSet stopOps; + // Modify the users of this op's result. + traverseForwardUpdateUserChainIf( + op, + /*conditionFn*/ + [op](Operation *curOp) { return isMetaUse(curOp) && curOp != op; }, + /*stopFn*/ + [&](Operation *curOp) { + // triton::LoadOp without MetaUse means it is an indirect load + // instead of the load providing the offset. + // The pattern is as follows, + // load -> ops -> load + // We need to ensure the intermediate ops are marked MixUse + // so that they will be replaced instead of be erased without + // conversion. + return isa(curOp) && !isMetaUse(curOp); + }, + /*actionFn*/ + [](OpBuilder &b, Operation *op) { + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(b.getContext())); }); + op->removeAttr("MetaUse"); + }, + stopOps); + LLVM_DEBUG({ + os << "[UseAnalysis] stopOps are \n"; + for (auto [idx, stopOp] : llvm::enumerate(stopOps)) + os << idx << ": " << *stopOp << "\n"; + }); + LLVM_DEBUG({ + os << "[UseAnalysis] After trace, funcOp is " << *funcOp << "\n"; + }); + for (auto *stopOp : stopOps) + setMixUseRecursively(stopOp, /*applyRoot=*/false); + LLVM_DEBUG({ + os << "[UseAnalysis] After traceback of stopOp, funcOp is " << *funcOp + << "\n"; + }); + // Modify this op. + LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + op->removeAttr("MetaUse"); + } + if (op->hasAttr(ConverterUtils::discreteAttrName)) + setMixUseRecursively(op); + if (auto loopOp = dyn_cast(op)) { + postProcessLoopOp(loopOp, solver); + } else if (auto ifOp = dyn_cast(op)) { + SmallVector yields(ifOp.thenYield().getOperands()); + if (!ifOp.getElseRegion().empty()) + yields.append(llvm::to_vector(ifOp.elseYield().getOperands())); + for (auto yield : yields) { + if (auto *defOp = yield.getDefiningOp()) + setMixUseRecursively(defOp); + } + } + }); + // Remove MetaUse in case of MixUse existing in the op + funcOp.walk([&](Operation *op) { + if (isMetaUse(op) && isMixUse(op)) { + op->removeAttr("MetaUse"); + } + }); + LLVM_DEBUG({ + os << "[UseAnalysis] After post-process, funcOp is " << *funcOp << "\n"; + }); + return success(); +} + +MetaUseEraser::MetaUseEraser(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} + +LogicalResult MetaUseEraser::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + LLVM_DEBUG({ + int64_t count = 0; + for (auto result : op->getResults()) { + count += std::distance(result.use_begin(), result.use_end()); + } + llvm::dbgs() << "Number of user: " << count << "\n"; + }); + if (isa(op)) { + return rewriter.notifyMatchFailure(op, + "AddPtrOp will be handled separately"); + } + if (isMetaUse(op)) { + rewriter.eraseOp(op); + return success(); + } + return rewriter.notifyMatchFailure(op, "requires meta ops"); +} diff --git a/third_party/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp new file mode 100644 index 000000000..949849e1f --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp @@ -0,0 +1,309 @@ +#include "TritonToUnstructure/BubbleUpOperation.h" +#include "Utils/Utils.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#define DEBUG_TYPE "triton-bubble-up-operation" + +BubbleUpExtract::BubbleUpExtract(MLIRContext 
*context, + bool enableAggressiveMode) + : OpRewritePattern(context), + enableAggressiveMode(enableAggressiveMode) {} + +LogicalResult +BubbleUpExtract::matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const { + auto tensorValue = op.getTensor(); + auto parentOp = tensorValue.getDefiningOp(); + auto indices = + SmallVector(op.getIndices().begin(), op.getIndices().end()); + auto loc = op.getLoc(); + + if (!parentOp || + (!enableAggressiveMode && !parentOp->hasOneUse())) { + return failure(); + } + + if (auto extsiOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, extsiOp, indices, loc, rewriter); + } else if (auto addIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, addIOp, indices, loc, rewriter); + } else if (auto subIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, subIOp, indices, loc, rewriter); + } else if (auto mulIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, mulIOp, indices, loc, rewriter); + } else if (auto divSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, divSIOp, indices, loc, rewriter); + } else if (auto remSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, remSIOp, indices, loc, rewriter); + } else if (auto maxSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, maxSIOp, indices, loc, rewriter); + } else if (auto minSIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, minSIOp, indices, loc, rewriter); + } else if (auto andIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, andIOp, indices, loc, rewriter); + } else if (auto orIOp = dyn_cast(parentOp)) { + bubbleUpIntBinaryOp(op, orIOp, indices, loc, rewriter); + } else if (auto cmpIOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, cmpIOp, indices, loc, rewriter); + } else if (auto truncFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, truncFOp, indices, loc, rewriter); + } else if (auto extFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, extFOp, indices, loc, rewriter); + } else if (auto fpTosiOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, fpTosiOp, indices, loc, rewriter); + } else if (auto siTofpOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, siTofpOp, indices, loc, rewriter); + } else if (auto clampFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, clampFOp, indices, loc, rewriter); + } else if (auto addFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, addFOp, indices, loc, rewriter); + } else if (auto subFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, subFOp, indices, loc, rewriter); + } else if (auto mulFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, mulFOp, indices, loc, rewriter); + } else if (auto divFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, divFOp, indices, loc, rewriter); + } else if (auto minNumFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, minNumFOp, indices, loc, + rewriter); + } else if (auto maxNumFOp = dyn_cast(parentOp)) { + bubbleUpFloatBinaryOp(op, maxNumFOp, indices, loc, + rewriter); + } else if (auto cmpFOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, cmpFOp, indices, loc, rewriter); + } else if (auto broadCastOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, broadCastOp, indices, loc, rewriter); + } else if (auto expandDimsOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, expandDimsOp, indices, loc, + rewriter); + } else if (auto splatOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, splatOp, indices, loc, rewriter); + } else if (auto makeRangeOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, makeRangeOp, indices, loc, + rewriter); + } else if (auto floorOp = dyn_cast(parentOp)) { + 
bubbleUpOperation(op, floorOp, indices, loc, rewriter); + } else if (auto ceilOp = dyn_cast(parentOp)) { + bubbleUpOperation(op, ceilOp, indices, loc, rewriter); + } else { + return failure(); + } + if (parentOp->use_empty()) + rewriter.eraseOp(parentOp); + + return success(); +} + +Value BubbleUpExtract::createExtractOp(Value value, ArrayRef indices, + Location loc, + PatternRewriter &rewriter) const { + auto extractedOp = rewriter.create(loc, value, indices); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template +void BubbleUpExtract::bubbleUpIntBinaryOp(Operation *op, BinOpTy binOp, + ArrayRef indices, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(binOp.getLhs(), indices, loc, rewriter); + auto rhs = createExtractOp(binOp.getRhs(), indices, loc, rewriter); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Binary\n" << *op << '\n' << binOp << '\n'; + }); + rewriter.replaceOpWithNewOp(op, lhs, rhs); +} + +template +void BubbleUpExtract::bubbleUpFloatBinaryOp(Operation *op, BinOpTy binOp, + ArrayRef indices, + Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(binOp.getLhs(), indices, loc, rewriter); + auto rhs = createExtractOp(binOp.getRhs(), indices, loc, rewriter); + rewriter.replaceOpWithNewOp(op, lhs, rhs); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::ExtSIOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); + auto resultType = cast(parentOp.getOut().getType()); + rewriter.replaceOpWithNewOp(op, resultType.getElementType(), + in); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::CmpIOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto lhs = createExtractOp(parentOp.getLhs(), indices, loc, rewriter); + auto rhs = createExtractOp(parentOp.getRhs(), indices, loc, rewriter); + rewriter.replaceOpWithNewOp(op, parentOp.getPredicateAttr(), + lhs, rhs); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, triton::BroadcastOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + auto srcShape = cast(src.getType()).getShape(); + SmallVector newIndices; + for (const auto [index, shape] : llvm::zip_equal(indices, srcShape)) { + if (shape == 1) { + newIndices.push_back( + rewriter.create(loc, rewriter.getIndexAttr(0))); + } else { + newIndices.push_back(index); + } + } + auto extractedOp = createExtractOp(src, newIndices, loc, rewriter); + rewriter.replaceOp(op, extractedOp); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, triton::ExpandDimsOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + SmallVector newIndices; + for (const auto index : llvm::enumerate(indices)) { + if (index.index() != parentOp.getAxis()) { + newIndices.push_back(index.value()); + } + } + auto extractedOp = createExtractOp(src, newIndices, loc, rewriter); + rewriter.replaceOp(op, extractedOp); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, triton::SplatOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + rewriter.replaceOp(op, src); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, 
triton::MakeRangeOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto resultType = cast(parentOp.getResult().getType()); + rewriter.replaceOpWithNewOp( + op, resultType.getElementType(), indices[0]); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::TruncFOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); + auto resultType = cast(parentOp.getOut().getType()); + rewriter.replaceOpWithNewOp(op, resultType.getElementType(), + in); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::ExtFOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); + auto resultType = cast(parentOp.getOut().getType()); + rewriter.replaceOpWithNewOp(op, resultType.getElementType(), + in); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::FPToSIOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); + auto resultType = cast(parentOp.getOut().getType()); + rewriter.replaceOpWithNewOp(op, resultType.getElementType(), + in); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::SIToFPOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); + auto outType = + cast(parentOp.getOut().getType()).getElementType(); + rewriter.replaceOpWithNewOp(op, outType, in); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, triton::ClampFOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto x = createExtractOp(parentOp.getX(), indices, loc, rewriter); + auto min = createExtractOp(parentOp.getMin(), indices, loc, rewriter); + auto max = createExtractOp(parentOp.getMax(), indices, loc, rewriter); + rewriter.replaceOpWithNewOp(op, x, min, max, + parentOp.getPropagateNan()); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, arith::CmpFOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto lhs = createExtractOp(parentOp.getLhs(), indices, loc, rewriter); + auto rhs = createExtractOp(parentOp.getRhs(), indices, loc, rewriter); + rewriter.replaceOpWithNewOp(op, parentOp.getPredicateAttr(), + lhs, rhs); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, math::FloorOp parentOp, ArrayRef indices, + Location loc, PatternRewriter &rewriter) const { + auto operand = createExtractOp(parentOp.getOperand(), indices, loc, rewriter); + rewriter.replaceOpWithNewOp(op, operand, + parentOp.getFastmath()); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + Operation *op, math::CeilOp parentOp, ArrayRef indices, Location loc, + PatternRewriter &rewriter) const { + auto operand = createExtractOp(parentOp.getOperand(), indices, loc, rewriter); + rewriter.replaceOpWithNewOp(op, operand, + parentOp.getFastmath()); +} + +BubbleUpOperationPass::BubbleUpOperationPass( + const BubbleUpOperationOptions &options) + : BubbleUpOperationBase(options) {} + +void BubbleUpOperationPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx, enableAggressiveMode); + + if 
(failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) { + moduleOp->emitError("failed to apply Patterns"); + signalPassFailure(); + } + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } +} + +std::unique_ptr> +triton::createBubbleUpOperationPass(const BubbleUpOperationOptions &options) { + return std::make_unique(options); +} \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToUnstructure/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/CMakeLists.txt new file mode 100644 index 000000000..802f5130b --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(TritonToUnstructure + UnstructureConversionPass.cpp + OffsetAnalysis.cpp + BubbleUpOperation.cpp + + DEPENDS + TritonToUnstructureConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonIR + TritonAnalysis + MLIRSCFTransforms +) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp new file mode 100644 index 000000000..7d5a6a5d9 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp @@ -0,0 +1,1055 @@ +#include "TritonToUnstructure/OffsetAnalysis.h" +#include "Utils/Utils.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-offset-analysis" + +namespace mlir { +namespace triton { + +PtrOffsetInfo::PtrOffsetInfo() : ptr(nullptr), offset(nullptr) {} + +PtrOffsetInfo::PtrOffsetInfo(const PtrOffsetInfo &other) { + *this = other; +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr) : ptr(ptr) { + setZeroOffset(); +} + +PtrOffsetInfo::PtrOffsetInfo(ArrayRef structured) : ptr(nullptr), offset(nullptr) { + setStructured(structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, bool structured) : ptr(ptr) { + setZeroOffset(); + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, ArrayRef structured) : ptr(ptr) { + setStructured(structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, const Value &offset, bool structured) : ptr(ptr), offset(offset) { + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), structured); +} + +PtrOffsetInfo::PtrOffsetInfo(const Value &ptr, const Value &offset, ArrayRef structured) : ptr(ptr), offset(offset) { + setStructured(structured); +} + +PtrOffsetInfo &PtrOffsetInfo::operator=(const PtrOffsetInfo &other) { + setPtr(other.getPtr()); + setOffset(other.getOffset()); + setStructured(other.getStructured()); + setScalarLike(other.isScalarLike()); + setNegativeFlag(other.isNegativeFlag()); + return *this; +} + +Value PtrOffsetInfo::getPtr() const { return this->ptr; } +Value PtrOffsetInfo::getOffset() const { return this->offset; } +bool PtrOffsetInfo::isScalarLike() const { return this->scalarLike; } +bool PtrOffsetInfo::isNegativeFlag() const { return this->negativeFlag; } + +SmallVector &PtrOffsetInfo::getStructuredRef() { return this->structured; } +const SmallVector &PtrOffsetInfo::getStructured() const { 
+ return this->structured; +} + +int PtrOffsetInfo::getRank() const { + return structured.size(); +} + +void PtrOffsetInfo::setPtr(const Value &ptr) { this->ptr = ptr; } +void PtrOffsetInfo::setOffset(const Value &offset) { this->offset = offset; } + +void PtrOffsetInfo::setStructured() { + assert(ptr && "ptr Should be to infer rank"); + this->structured.clear(); + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), true); +} + +void PtrOffsetInfo::setStructured(int rank) { + this->structured.clear(); + this->structured.resize(rank, true); +} + +void PtrOffsetInfo::setUnstructured() { + assert(ptr && "ptr Should be to infer rank"); + this->structured.clear(); + if (auto tensorType = dyn_cast(ptr.getType())) + this->structured.resize(tensorType.getRank(), false); +} + +void PtrOffsetInfo::setUnstructured(int rank) { + this->structured.clear(); + this->structured.resize(rank, false); +} + +void PtrOffsetInfo::setStructured(ArrayRef structured) { + this->structured.resize(structured.size()); + for (size_t i = 0; i < structured.size(); i++) + this->structured[i] = structured[i]; +} + +void PtrOffsetInfo::setStructured(const PtrOffsetInfo &other) { + this->setStructured(other.getStructured()); +} + +void PtrOffsetInfo::setNegativeFlag(bool negativeFlag) { + this->negativeFlag = negativeFlag; +} + +void PtrOffsetInfo::setScalarLike(bool scalarLike) { + this->scalarLike = scalarLike; +} + +bool PtrOffsetInfo::isStructured(int dim) const { + return this->scalarLike || structured[dim]; +} + +bool PtrOffsetInfo::isStructured() const { + return this->scalarLike || + llvm::all_of(structured, [](auto dim) { return dim; }); +} + +bool PtrOffsetInfo::isUnstructured() const { + return llvm::all_of(structured, [](auto dim) { return !dim; }); +} + +void PtrOffsetInfo::setZeroOffset() { + if (!ptr) + return; + Value offset; + OpBuilder builder(ptr.getContext()); + builder.setInsertionPointToStart(ptr.getParentBlock()); + if (auto tensorType = dyn_cast(ptr.getType())) { + offset = builder.create( + ptr.getLoc(), DenseElementsAttr::get( + RankedTensorType::get(tensorType.getShape(), + builder.getIntegerType(64)), + builder.getZeroAttr(builder.getIntegerType(64)))); + } else { + offset = builder.create( + ptr.getLoc(), builder.getI64IntegerAttr(0)); + } + setOffset(offset); +} + +PtrOffsetInfo combineInfo(const PtrOffsetInfo &lhs, const PtrOffsetInfo &rhs) { + PtrOffsetInfo info; + assert(lhs.getRank() == rhs.getRank() && + "Rank must be same to be combined"); + + info.setScalarLike(lhs.isScalarLike() && + rhs.isScalarLike()); + SmallVector &structuredRef = info.getStructuredRef(); + structuredRef.resize(lhs.getRank()); + for (size_t i = 0; i < structuredRef.size(); i++) + structuredRef[i] = lhs.isStructured(i) && rhs.isStructured(i); + info.setNegativeFlag(lhs.isNegativeFlag() || rhs.isNegativeFlag()); + return info; +} + +void parse(Value operand, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + if (offsetMap.contains(operand)) { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "found\n" << operand << '\n'; + }); + return; + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "parse\n" << operand << '\n'; + }); + + if (auto *defOp = operand.getDefiningOp()) { + if (isa(defOp->getDialect())) { + parseArithOp(defOp, loc, rewriter, offsetMap); + } else if (isa(defOp->getDialect())) { + parseTritonOp(defOp, loc, rewriter, offsetMap); + } else { + if (auto ifOp = dyn_cast(defOp)) { + parseIf(ifOp, loc, rewriter, offsetMap, operand); + } 
else if (auto yieldOp = dyn_cast(defOp)) { + parseYield(yieldOp, loc, rewriter, offsetMap); + } else if (auto loopOp = dyn_cast(defOp)) { + parseLoopOp(loopOp, loc, rewriter, offsetMap, operand); + } else if (auto extractOp = dyn_cast(defOp)) { + parseExtract(extractOp, loc, rewriter, offsetMap); + } + } + } else if (auto ptrType = dyn_cast(operand.getType())) { + offsetMap[operand] = PtrOffsetInfo(operand, true); + } else if (auto blockArgument = dyn_cast(operand)) { + auto parentOp = blockArgument.getOwner()->getParentOp(); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Handling block argument\n" << *parentOp << '\n'; + }); + if (isa(parentOp)) { + offsetMap[operand] = PtrOffsetInfo(); + } else if (auto loopOp = dyn_cast(parentOp)) { + parseLoopRegionIterArg(loopOp, loc, rewriter, offsetMap, blockArgument); + } + } else { + llvm_unreachable("Unreachable"); + } + + if (!offsetMap.contains(operand)) { + offsetMap[operand] = PtrOffsetInfo(); + if (auto tensorType = dyn_cast(operand.getType())) + offsetMap[operand].setUnstructured(tensorType.getRank()); + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "finish parse\n" << operand << '\n'; + auto data = offsetMap.at(operand); + for (auto s : data.getStructuredRef()) + os << s; + os << "\n"; + os << "FNparse: " << operand << " ,isNegativeFlag: " << data.isNegativeFlag() << "\n"; + }); +} + +void parseLoopRegionIterArg(LoopLikeOpInterface loopOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap, + BlockArgument regionIterArg) { + auto regionIterArgInfo = PtrOffsetInfo(regionIterArg); + OpOperand *initArgOperand = loopOp.getTiedLoopInit(regionIterArg); + if (!initArgOperand) + return; + Value initArg = initArgOperand->get(); + parse(initArg, loc, rewriter, offsetMap); + regionIterArgInfo.setStructured(offsetMap[initArg]); + offsetMap[regionIterArg] = regionIterArgInfo; +} + +void parseArithOp(Operation *arithOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + assert(isa(arithOp->getDialect())); + if (auto addIOp = dyn_cast(arithOp)) { + parseAddI(addIOp, loc, rewriter, offsetMap); + } else if (auto subIOp = dyn_cast(arithOp)) { + parseBinaryOp(subIOp, loc, rewriter, offsetMap); + } else if (auto indexCastOp = dyn_cast(arithOp)) { + parseIndexCast(indexCastOp, loc, rewriter, offsetMap); + } else if (auto constantFloatOp = dyn_cast(arithOp)) { + parseConstantOp(constantFloatOp, loc, rewriter, offsetMap); + } else if (auto constantIntOp = dyn_cast(arithOp)) { + parseConstantOp(constantIntOp, loc, rewriter, offsetMap); + } else if (auto constantOp = dyn_cast(arithOp)) { + parseConstantOp(constantOp, loc, rewriter, offsetMap); + } else if (auto extSIOp = dyn_cast(arithOp)) { + parseExtSI(extSIOp, loc, rewriter, offsetMap); + } else if (auto mulIOp = dyn_cast(arithOp)) { + parseMulI(mulIOp, loc, rewriter, offsetMap); + } else if (auto remSIOp = dyn_cast(arithOp)) { + parseBinaryOp(remSIOp, loc, rewriter, offsetMap); + } else if (auto divSIOp = dyn_cast(arithOp)) { + parseBinaryOp(divSIOp, loc, rewriter, offsetMap); + } else if (auto selectOp = dyn_cast(arithOp)) { + parseSelect(selectOp, loc, rewriter, offsetMap); + } else if (auto fPToSIOp = dyn_cast(arithOp)) { + parseFPToSI(fPToSIOp, loc, rewriter, offsetMap); + } else if (auto sIToFPOp = dyn_cast(arithOp)) { + parseSIToFP(sIToFPOp, loc, rewriter, offsetMap); + } else if (auto mulFOp = dyn_cast(arithOp)) { + parseBinaryOp(mulFOp, loc, rewriter, offsetMap); + } else if (auto divFOp = dyn_cast(arithOp)) { + parseBinaryOp(divFOp, loc, 
rewriter, offsetMap); + } else if (auto addFOp = dyn_cast(arithOp)) { + parseBinaryOp(addFOp, loc, rewriter, offsetMap); + } else if (auto subFOp = dyn_cast(arithOp)) { + parseBinaryOp(subFOp, loc, rewriter, offsetMap); + } else if (auto minNumFOp = dyn_cast(arithOp)) { + parseBinaryOp(minNumFOp, loc, rewriter, offsetMap); + } else if (auto maxNumFOp = dyn_cast(arithOp)) { + parseBinaryOp(maxNumFOp, loc, rewriter, offsetMap); + } else if (auto maxSIOp = dyn_cast(arithOp)) { + parseBinaryOp(maxSIOp, loc, rewriter, offsetMap); + } else if (auto minSIOp = dyn_cast(arithOp)) { + parseBinaryOp(minSIOp, loc, rewriter, offsetMap); + } else if (auto cmpIOp = dyn_cast(arithOp)) { + parseBinaryOp(cmpIOp, loc, rewriter, offsetMap); + } else if (auto andIOp = dyn_cast(arithOp)) { + parseBinaryOp(andIOp, loc, rewriter, offsetMap); + } else if (auto orIOp = dyn_cast(arithOp)) { + parseBinaryOp(orIOp, loc, rewriter, offsetMap); + } +} + +void parseTritonOp(Operation *tritonOp, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + assert(isa(tritonOp->getDialect())); + if (auto addPtrOp = dyn_cast(tritonOp)) { + parseAddPtr(addPtrOp, loc, rewriter, offsetMap); + } else if (auto splatOp = dyn_cast(tritonOp)) { + parseSplat(splatOp, loc, rewriter, offsetMap); + } else if (auto getProgramIdOp = dyn_cast(tritonOp)) { + parseConstantOp(getProgramIdOp, loc, rewriter, offsetMap); + } else if (auto getNumProgramsOp = + dyn_cast(tritonOp)) { + parseConstantOp(getNumProgramsOp, loc, rewriter, offsetMap); + } else if (auto makeRangeOp = dyn_cast(tritonOp)) { + parseMakeRange(makeRangeOp, loc, rewriter, offsetMap); + } else if (auto bitcastOp = dyn_cast(tritonOp)) { + parseBitcast(bitcastOp, loc, rewriter, offsetMap); + } else if (auto loadOp = dyn_cast(tritonOp)) { + parseLoad(loadOp, loc, rewriter, offsetMap); + } else if (auto broadcastOp = dyn_cast(tritonOp)) { + parseBroadcast(broadcastOp, loc, rewriter, offsetMap); + } else if (auto expandDimsOp = dyn_cast(tritonOp)) { + parseExpandDims(expandDimsOp, loc, rewriter, offsetMap); + } else if (auto clampFOp = dyn_cast(tritonOp)) { + parseClampF(clampFOp, loc, rewriter, offsetMap); + } else if (auto makeTensorDescOp = + dyn_cast(tritonOp)) { + parseMakeTensorDesc(makeTensorDescOp, loc, rewriter, offsetMap); + } else if (auto makeTensorPtrOp = + dyn_cast(tritonOp)) { + parseMakeTensorPtr(makeTensorPtrOp, loc, rewriter, offsetMap); + } else if (auto reduceOp = dyn_cast(tritonOp)) { + parseReduce(reduceOp, loc, rewriter, offsetMap); + } else if (auto reduceReturnOp = dyn_cast(tritonOp)) { + parseReduceReturn(reduceReturnOp, loc, rewriter, offsetMap); + } else if (auto advanceOp = dyn_cast(tritonOp)) { + parseAdvance(advanceOp, loc, rewriter, offsetMap); + } else if (auto intToPtrOp = dyn_cast(tritonOp)) { + parseIntToPtr(intToPtrOp, loc, rewriter, offsetMap); + } +} + +void parseAddPtr(triton::AddPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get addPtr base_ptr + Value ptr = op.getPtr(); + parse(ptr, op.getLoc(), rewriter, offsetMap); + // Get addPtr offset + Value offsetValue = op.getOffset(); + parse(offsetValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo ptrOffsetInfo = offsetMap.at(ptr); + PtrOffsetInfo offsetOffsetInfo = offsetMap.at(offsetValue); + // Modify IR + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto offsetType = dyn_cast(offsetValue.getType())) { + auto offsetElementType = cast(offsetType.getElementType()); + if 
(offsetElementType.getWidth() != 64) { + auto newOffsetType = RankedTensorType::get(offsetType.getShape(), + rewriter.getIntegerType(64)); + offsetValue = rewriter.create(op.getLoc(), newOffsetType, + offsetValue); + } + } else { + auto offsetIntType = cast(offsetValue.getType()); + if (offsetIntType.getWidth() != 64) { + offsetValue = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), offsetValue); + } + } + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr] Adding offset\n"; + os << ptrOffsetInfo.getOffset() << '\n' << offsetValue << '\n'; + }); + Value offset = rewriter.create( + op.getLoc(), ptrOffsetInfo.getOffset(), offsetValue); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr] offset is\n" << offset << '\n'; + }); + // Set addPtr offset map + auto dst = op.getResult(); + auto dstOffsetInfo = combineInfo(ptrOffsetInfo, offsetOffsetInfo); + dstOffsetInfo.setPtr(ptrOffsetInfo.getPtr()); + dstOffsetInfo.setOffset(offset); + offsetMap[dst] = dstOffsetInfo; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + SmallVector &ptrStructured = ptrOffsetInfo.getStructuredRef(); + SmallVector &offsetStructured = offsetOffsetInfo.getStructuredRef(); + os << "[parseAddPtr] ptrStructured: "; + for (size_t i = 0; i < ptrStructured.size(); i++) + os << ptrStructured[i]; + os << "\n"; + os << "[parseAddPtr] offsetStructured: "; + for (size_t i = 0; i < offsetStructured.size(); i++) + os << offsetStructured[i]; + os << "\n"; + os << "[parseAddPtr] offsetOffsetInfo.isNegativeFlag(): "; + os << offsetOffsetInfo.isNegativeFlag(); + os << "\n"; + os << "[parseAddPtr] ptrOffsetInfo.isNegativeFlag(): "; + os << ptrOffsetInfo.isNegativeFlag(); + os << "\n"; + }); +} + +void parseSplat(triton::SplatOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get splat src + auto src = op.getSrc(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + auto dst = op.getResult(); + auto dstType = cast(dst.getType()); + PtrOffsetInfo dstOffsetInfo(srcOffsetInfo.getPtr()); + // Modify IR + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseSplat] dst is\n" << dst << '\n'; + }); + if (isa(dstType.getElementType())) { + RewriterBase::InsertionGuard guard(rewriter); + auto dstShape = dstType.getShape(); + rewriter.setInsertionPoint(op); + Value valueOffset = srcOffsetInfo.getOffset(); + Value offset = rewriter.create( + loc, RankedTensorType::get(dstShape, rewriter.getIntegerType(64)), + valueOffset); + dstOffsetInfo.setOffset(offset); + } + // Set addPtr offset map + + dstOffsetInfo.setStructured(dstType.getRank()); + dstOffsetInfo.setScalarLike(true); + dstOffsetInfo.setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + offsetMap[dst] = dstOffsetInfo; +} + +template +void parseBinaryOp(BinOpTy op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + SmallVector &lhsStructured = lhsOffsetInfo.getStructuredRef(); + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + SmallVector &rhsStructured = rhsOffsetInfo.getStructuredRef(); + auto dst = op->getResult(0); + PtrOffsetInfo dstOffsetInfo; + dstOffsetInfo.setScalarLike(lhsOffsetInfo.isScalarLike() && + rhsOffsetInfo.isScalarLike()); + if (dstOffsetInfo.isScalarLike()) + dstOffsetInfo.setStructured(lhsStructured.size()); + else + 
dstOffsetInfo.setUnstructured(lhsStructured.size()); + + if (isa(op.getOperation())) { + dstOffsetInfo.setNegativeFlag(true); + } else { + dstOffsetInfo.setNegativeFlag(lhsOffsetInfo.isNegativeFlag() || + rhsOffsetInfo.isNegativeFlag()); + } + offsetMap[dst] = dstOffsetInfo; +} + +void parseAddI(arith::AddIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get addi lhs + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + // Get addi rhs + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + // Set addi offset map + auto dst = op.getResult(); + offsetMap[dst] = combineInfo(lhsOffsetInfo, rhsOffsetInfo); +} + +void parseIndexCast(arith::IndexCastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get indexCast input + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + // Set indexCast offset map + auto dst = op.getOut(); + offsetMap[dst] = offsetMap.at(src); +} + +template +bool isConstantNegative(AttrTy attr, TypeTy type) { + if constexpr (std::is_same_v && + std::is_same_v) { + return attr.getInt() < 0; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return attr.getValueAsDouble() < 0.0; + } else if constexpr(std::is_same_v && + std::is_same_v) { + return attr.getInt() < 0; + } else if constexpr (std::is_base_of_v && + std::is_base_of_v) { + auto tensorType = mlir::cast(type); + auto elemType = tensorType.getElementType(); + + if (auto denseIntAttr = dyn_cast(attr)) { + if (auto intElemType = dyn_cast(elemType)) { + for (auto elemVal : denseIntAttr.template getValues()) { + auto elemAttr = mlir::IntegerAttr::get(intElemType, elemVal); + if (isConstantNegative(elemAttr, intElemType)) { + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO: Tensor has negative element: " << elemAttr << "\n"; + }); + return true; + } + } + return false; + } + } + + else if (auto denseFloatAttr = dyn_cast(attr)) { + if (auto floatElemType = dyn_cast(elemType)) { + for (auto elemVal : denseFloatAttr.template getValues()) { + auto elemAttr = mlir::FloatAttr::get(floatElemType, elemVal); + if (isConstantNegative(elemAttr, floatElemType)) { + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO: Tensor has negative element: " << elemAttr << "\n"; + }); + return true; + } + } + return false; + } + } + + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO: Unsupported tensor elemType: " << elemType + << ",tensorType:" << tensorType << "\n"; + }); + return false; + } else { + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO, Unsupported: attr: " << attr + << ", type: " << type << " \n"; + }); + return false; + } +} + +template +void parseConstantOp(ConstOpTy dst, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + mlir::Operation *opPtr = nullptr; + if constexpr (std::is_pointer_v) { + if (dst != nullptr) { + opPtr = dst->getOperation(); + } + } else { + opPtr = dst.getOperation(); + } + + mlir::Value opResult = opPtr->getResult(0); + + offsetMap[opResult] = PtrOffsetInfo(); + offsetMap[opResult].setScalarLike(true); + if (auto tensorType = mlir::dyn_cast(opResult.getType())) { + offsetMap[opResult].setStructured(tensorType.getRank()); + } + + auto constantOp = mlir::dyn_cast(opPtr); + if (!constantOp) { + LLVM_DEBUG({ + llvm::dbgs() << "Warning: Non-ConstantOp (" << opPtr->getName() + << ") passed to parseConstantOp\n"; + }); + return; + 
} + + mlir::Attribute constAttr = constantOp.getValue(); + mlir::Type resultType = opResult.getType(); + + if (auto intType = dyn_cast(resultType)) { + if (auto intAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, intType)); + } + } else if (auto floatType = dyn_cast(resultType)) { + if (auto floatAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(floatAttr, floatType)); + } + } else if (auto indexType = dyn_cast(resultType)) { + if (auto intAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, indexType)); + } + } else if (auto indexType = dyn_cast(resultType)) { + if (auto intAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, indexType)); + } + } else { + llvm_unreachable("PCO: Non-ConstantOp passed to parseConstantOp \n"); + } +} + +void parseMakeRange(triton::MakeRangeOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set makeRange offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setStructured(1); +} + +void parseExtSI(arith::ExtSIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get extSI input + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + // Set extSI offset map + auto dst = op.getOut(); + offsetMap[dst] = offsetMap.at(src); +} + +void parseBitcast(triton::BitcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get bitcast src + auto src = op.getSrc(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set extSI offset map + auto dst = op.getResult(); + if (auto ptr = srcOffsetInfo.getPtr()) { + Type ptrType = dst.getType(); + if (auto tensorType = dyn_cast(ptrType)) + ptrType = tensorType.getElementType(); + rewriter.setInsertionPoint(op); + ptr = rewriter.create(loc, ptrType, ptr); + offsetMap[dst] = PtrOffsetInfo(ptr, srcOffsetInfo.getOffset(), srcStructured); + } else { + offsetMap[dst] = PtrOffsetInfo(srcStructured); + } + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); +} + +void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get load ptr + auto ptr = op.getPtr(); + parse(ptr, op.getLoc(), rewriter, offsetMap); + // Set load offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(offsetMap[ptr].isScalarLike()); + offsetMap[dst].setNegativeFlag(offsetMap[ptr].isNegativeFlag()); + auto tensorType = dyn_cast(dst.getType()); + if (!tensorType) + return; + offsetMap[dst].setUnstructured(tensorType.getRank()); +} + +void parseMulI(arith::MulIOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get muli lhs + auto lhs = op.getLhs(); + parse(lhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo lhsOffsetInfo = offsetMap.at(lhs); + SmallVector &lhsStructured = lhsOffsetInfo.getStructuredRef(); + bool lhsScalarLike = lhsOffsetInfo.isScalarLike(); + // Get muli rhs + auto rhs = op.getRhs(); + parse(rhs, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo rhsOffsetInfo = offsetMap.at(rhs); + SmallVector &rhsStructured = rhsOffsetInfo.getStructuredRef(); + bool rhsScalarLike = 
rhsOffsetInfo.isScalarLike(); + // Set muli offset map + size_t maxSize = std::max(lhsStructured.size(), rhsStructured.size()); + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(lhsScalarLike && rhsScalarLike); + offsetMap[dst].setNegativeFlag(lhsOffsetInfo.isNegativeFlag() + || rhsOffsetInfo.isNegativeFlag()); + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(maxSize); + for (size_t i = 0; i < maxSize; i++) + if (lhsScalarLike) + dstStructured[i] = rhsStructured[i]; + else if (rhsScalarLike) + dstStructured[i] = lhsStructured[i]; + else + dstStructured[i] = false; +} + +void parseBroadcast(triton::BroadcastOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get broadcast src + auto src = op.getSrcMutable().get(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Get broadcast dim + auto dst = op.getResult(); + assert(isa(src.getType()) && + "tt.broadcast's input should be a tensor"); + auto srcType = cast(src.getType()); + auto dstType = cast(dst.getType()); + assert(srcType.getRank() == dstType.getRank() && + "rank of source should be equal to that of destination"); + auto broadcastDim = ConverterUtils::getBroadcastDims(srcType, dstType); + // Set broadcast offset map + offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr()); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + + if (srcOffsetInfo.getPtr()) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Value valueOffset = srcOffsetInfo.getOffset(); + Value offset = rewriter.create( + loc, + RankedTensorType::get(dstType.getShape(), rewriter.getIntegerType(64)), + valueOffset); + + offsetMap[dst].setOffset(offset); + } + + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(srcStructured.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (llvm::find(broadcastDim, i) != broadcastDim.end()) + dstStructured[i] = true; + else + dstStructured[i] = srcStructured[i]; +} + +void parseExpandDims(triton::ExpandDimsOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get expandDims src + auto src = op.getSrcMutable().get(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set expandDims offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr()); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + if (srcOffsetInfo.getPtr()) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Value valueOffset = srcOffsetInfo.getOffset(); + Value offset = rewriter.create(loc, valueOffset, + op.getAxisAttr()); + + offsetMap[dst].setOffset(offset); + } + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(srcStructured.size() + 1); + size_t j = 0; + for (size_t i = 0; i < dstStructured.size(); i++) + if (i == op.getAxis()) { + dstStructured[i] = true; + } else { + dstStructured[i] = srcStructured[j]; + j++; + } +} + +void parseClampF(triton::ClampFOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) {
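+ // tt.clamp operates element-wise on (x, min, max); its result carries no
+ // pointer, so only scalar-likeness and the negative-value flag are combined
+ // below and tensor results are conservatively marked unstructured.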
+ // Get clampF x + auto src = op.getX(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + // Get clampF min + auto clampMin = op.getMin(); + parse(clampMin, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo minOffsetInfo = offsetMap.at(clampMin); + // Get clampF max + auto clampMax = op.getMax(); + parse(clampMax, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo maxOffsetInfo = offsetMap.at(clampMax); + // Set clampF offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike() && + minOffsetInfo.isScalarLike() && + maxOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag() || + minOffsetInfo.isNegativeFlag() || + maxOffsetInfo.isNegativeFlag()); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseSelect(arith::SelectOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get select condition + auto condition = op.getCondition(); + parse(condition, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo conditionOffsetInfo = offsetMap.at(condition); + bool conditionScalarLike = conditionOffsetInfo.isScalarLike(); + // Get select trueValue + auto trueValue = op.getTrueValue(); + parse(trueValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo trueValueOffsetInfo = offsetMap.at(trueValue); + SmallVector &trueValueStructured = trueValueOffsetInfo.getStructuredRef(); + bool trueValueScalarLike = trueValueOffsetInfo.isScalarLike(); + // Get select falseValue + auto falseValue = op.getFalseValue(); + parse(falseValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo falseValueOffsetInfo = offsetMap.at(falseValue); + SmallVector &falseValueStructured = falseValueOffsetInfo.getStructuredRef(); + bool falseValueScalarLike = falseValueOffsetInfo.isScalarLike(); + // Set select offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(conditionScalarLike && trueValueScalarLike && + falseValueScalarLike); + auto dstType = dyn_cast(dst.getType()); + offsetMap[dst].setNegativeFlag(trueValueOffsetInfo.isNegativeFlag() || + falseValueOffsetInfo.isNegativeFlag()); + if (!dstType) + return; + if (offsetMap[dst].isScalarLike()) + offsetMap[dst].setStructured(dstType.getRank()); + else + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseFPToSI(arith::FPToSIOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get FPToSI src + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + // Set FPToSI offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + if (offsetMap[dst].isScalarLike()) + offsetMap[dst].setStructured(dstType.getRank()); + else + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseSIToFP(arith::SIToFPOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get SIToFP src + auto src = op.getIn(); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + // Set SIToFP offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo();
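+ // sitofp is element-wise: scalar-likeness and the negative-value flag are
+ // inherited from the integer operand; structure is kept only for scalar-like
+ // results, otherwise the result is marked unstructured below. +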
offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + if (offsetMap[dst].isScalarLike()) + offsetMap[dst].setStructured(dstType.getRank()); + else + offsetMap[dst].setUnstructured(dstType.getRank()); +} + +void parseMakeTensorDesc(triton::MakeTensorDescOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set MakeTensorDesc offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setStructured(dstType.getRank()); +} + +void parseMakeTensorPtr(triton::MakeTensorPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set MakeTensorPtr offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setStructured(dstType.getRank()); +} + +void parseAdvance(triton::AdvanceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Set Advance offset map + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + auto dstType = dyn_cast(dst.getType()); + if (!dstType) + return; + offsetMap[dst].setStructured(dstType.getRank()); +} + +void parseReduce(triton::ReduceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get reduce src + Value src = op->getOperand(0); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set reduce offset map + Value dst = op->getResult(0); + auto dstType = dyn_cast(dst.getType()); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + if (!dstType) + return; + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + auto dstShape = dstType.getShape(); + dstStructured.resize(dstShape.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (dstShape[i] == 1) + dstStructured[i] = true; + else + dstStructured[i] = srcStructured[i]; +} + +void parseReduceReturn(triton::ReduceReturnOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get reduce src + Value src = op->getOperand(0); + parse(src, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo srcOffsetInfo = offsetMap.at(src); + SmallVector &srcStructured = srcOffsetInfo.getStructuredRef(); + // Set reduce offset map + Value dst = op->getResult(0); + auto dstType = dyn_cast(dst.getType()); + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); + if (!dstType) + return; + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + auto dstShape = dstType.getShape(); + dstStructured.resize(dstShape.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (dstShape[i] == 1) + dstStructured[i] = true; + else + dstStructured[i] = srcStructured[i]; +} + +void parseIf(scf::IfOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst) { + const unsigned int index = cast(dst).getResultNumber(); + // Get if then region + Block &thenBlock = op.getThenRegion().front(); + Value thenYieldedValue = 
thenBlock.getTerminator()->getOperand(index); + parse(thenYieldedValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo thenOffsetInfo = offsetMap.at(thenYieldedValue); + SmallVector &thenStructured = thenOffsetInfo.getStructuredRef(); + // Get if else region + bool dstIsScalar = thenOffsetInfo.isScalarLike(); + SmallVector elseStructured; + if (op.elseBlock()) { + Block &elseBlock = op.getElseRegion().front(); + Value elseYieldedValue = elseBlock.getTerminator()->getOperand(index); + parse(elseYieldedValue, op.getLoc(), rewriter, offsetMap); + PtrOffsetInfo elseOffsetInfo = offsetMap.at(elseYieldedValue); + elseStructured = elseOffsetInfo.getStructuredRef(); + dstIsScalar = dstIsScalar && elseOffsetInfo.isScalarLike(); + } + // Set if offset map + offsetMap[dst] = PtrOffsetInfo(); + offsetMap[dst].setScalarLike(dstIsScalar); + offsetMap[dst].setNegativeFlag(thenOffsetInfo.isNegativeFlag()); + SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); + dstStructured.resize(thenStructured.size()); + for (size_t i = 0; i < dstStructured.size(); i++) + if (op.elseBlock()) + dstStructured[i] = thenStructured[i] && elseStructured[i]; + else + dstStructured[i] = thenStructured[i]; +} + +void parseYield(scf::YieldOp op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get yield src + for (auto src : op->getOperands()) + parse(src, op.getLoc(), rewriter, offsetMap); +} + +void parseLoopOp(LoopLikeOpInterface op, const Location &loc, RewriterBase &rewriter, + llvm::DenseMap &offsetMap, Value dst) { + auto resNum = cast(dst).getResultNumber(); + Value yieldedValue = op.getYieldedValues()[resNum]; + parse(yieldedValue, op.getLoc(), rewriter, offsetMap); + offsetMap[dst] = PtrOffsetInfo() = offsetMap.at(yieldedValue); +} + +void parseExtractSlice(tensor::ExtractSliceOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + // Get extractSlice src + auto src = op.getOperand(0); + parse(src, op.getLoc(), rewriter, offsetMap); + // Set extractSlice offset map + auto dst = op.getResult(); + offsetMap[dst] = offsetMap.at(src); +} + +void parseExtract(tensor::ExtractOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + auto parentValue = op.getTensor(); + parse(parentValue, op.getLoc(), rewriter, offsetMap); + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(); + if (isa(dst.getType())) { + offsetMap[dst].setPtr(dst); + } + offsetMap[dst].setScalarLike(true); +} + +void parseIntToPtr(triton::IntToPtrOp op, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + auto dst = op.getResult(); + offsetMap[dst] = PtrOffsetInfo(dst); + offsetMap[dst].setScalarLike(true); + + parse(op.getSrc(), op.getLoc(), rewriter, offsetMap); + auto srcOffsetInfo = offsetMap.at(op.getSrc()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); +} + +} // namespace triton +} // namespace mlir \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp new file mode 100644 index 000000000..4e0cafce3 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp @@ -0,0 +1,503 @@ +#include "TritonToUnstructure/UnstructureConversionPass.h" +#include "Utils/Utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include 
"mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/STLExtras.h" + +#include + +#define DEBUG_TYPE "triton-unstructure-converter" + +using namespace mlir; +using namespace triton; + +#include "llvm/Support/Debug.h" + +template +Value UnstructuredMemAccessConverter::createExtractOp( + Location loc, Value value, ArrayRef iterIdx, + PatternRewriter &rewriter) const { + if (!value) + return value; + auto extractedOp = rewriter.create(loc, value, iterIdx); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template +MemAccOpTy UnstructuredMemAccessConverter::createMemAccOp( + MemAccOpTy op, Value ptrToAccess, Location loc, ArrayRef iterIdx, + PatternRewriter &rewriter) const { + llvm_unreachable("Unhandled discrete memory access operation"); +} + +template <> +triton::LoadOp UnstructuredMemAccessConverter::createMemAccOp( + triton::LoadOp op, Value ptrToAccess, Location loc, ArrayRef iterIdx, + PatternRewriter &rewriter) const { + return rewriter.create(loc, ptrToAccess, op.getCache(), + op.getEvict(), false); +} + +template <> +triton::AtomicRMWOp +UnstructuredMemAccessConverter::createMemAccOp( + triton::AtomicRMWOp op, Value ptrToAccess, Location loc, + ArrayRef iterIdx, PatternRewriter &rewriter) const { + auto extractedValue = createExtractOp(loc, op.getVal(), iterIdx, rewriter); + auto extractedMask = createExtractOp(loc, op.getMask(), iterIdx, rewriter); + auto resultType = cast(op.getResult().getType()); + SmallVector scalarLikeShape(resultType.getRank(), 1); + auto scalarLikeType = + RankedTensorType::get(scalarLikeShape, resultType.getElementType()); + auto splatedPtrToAccess = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), + ptrToAccess); + auto splatedExtractedValue = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), + extractedValue); + auto splatedExtractedMask = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedMask.getType()), + extractedMask); + return rewriter.create( + loc, scalarLikeType, op.getAtomicRmwOpAttr(), splatedPtrToAccess, + splatedExtractedValue, splatedExtractedMask, op.getSemAttr(), + op.getScopeAttr()); +} + +template <> +triton::AtomicCASOp +UnstructuredMemAccessConverter::createMemAccOp( + triton::AtomicCASOp op, Value ptrToAccess, Location loc, + ArrayRef iterIdx, PatternRewriter &rewriter) const { + auto extractedCmp = createExtractOp(loc, op.getCmp(), iterIdx, rewriter); + auto extractedValue = createExtractOp(loc, op.getVal(), iterIdx, rewriter); + auto resultType = cast(op.getResult().getType()); + SmallVector scalarLikeShape(resultType.getRank(), 1); + auto scalarLikeType = + RankedTensorType::get(scalarLikeShape, resultType.getElementType()); + auto splatedPtrToAccess = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), + ptrToAccess); + auto splatedExtractedCmp = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedCmp.getType()), + extractedCmp); + auto splatedExtractedValue = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), + extractedValue); + return rewriter.create( + loc, scalarLikeType, splatedPtrToAccess, splatedExtractedCmp, + 
splatedExtractedValue, op.getSemAttr(), op.getScopeAttr()); +} + +template <> +triton::StoreOp UnstructuredMemAccessConverter::createMemAccOp( + triton::StoreOp op, Value ptrToAccess, Location loc, + ArrayRef iterIdx, PatternRewriter &rewriter) const { + auto extractedValue = createExtractOp(loc, op.getValue(), iterIdx, rewriter); + auto extractedMask = createExtractOp(loc, op.getMask(), iterIdx, rewriter); + return rewriter.create(loc, ptrToAccess, extractedValue, + extractedMask); +} + +template <> +template <> +void UnstructuredMemAccessConverter::splatAndLoadScenario< + triton::LoadOp>(triton::LoadOp op, int rank, + PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector idx( + rank, rewriter.create(loc, rewriter.getIndexAttr(0))); + auto extractedPtr = createExtractOp(loc, op.getPtr(), idx, rewriter); + Value mask = op.getMask(); + Value other = op.getOther(); + Value loadedValue = rewriter.create( + loc, extractedPtr, /*mask=*/nullptr, /*other=*/nullptr, + /*boundaryCheck=*/ArrayRef(), + /*PaddingOptionAttr=*/nullptr); + loadedValue = rewriter.create(loc, op.getResult().getType(), + loadedValue); + if (mask) + rewriter.replaceOpWithNewOp(op, mask, loadedValue, other); + else + rewriter.replaceOp(op, loadedValue); +} + +template +void UnstructuredMemAccessConverter::AddAssertForAddPtr( + MemAccOpTy op, const Value &opoffset, PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto opoffsetType = opoffset.getType(); + Value constantZero; + + op->setAttr("Negative", UnitAttr::get(rewriter.getContext())); + if (auto tensorType = dyn_cast(opoffsetType)) { + constantZero = rewriter.create( + loc, rewriter.getZeroAttr(tensorType)); + } else { + constantZero = rewriter.create( + loc, 0, opoffset.getType()); + } + Value cmpResult = rewriter.create( + loc, arith::CmpIPredicate::sge, opoffset, constantZero); + + mlir::StringAttr assertMsg = rewriter.getStringAttr( + "AddPtr offset (from subi) must be >= 0"); + + rewriter.create(loc, cmpResult, assertMsg); +} + +template +UnstructuredMemAccessConverter::UnstructuredMemAccessConverter( + MLIRContext *context, const llvm::DenseMap &offsetMap) + : OpRewritePattern(context), offsetMap(offsetMap) {} + +template +LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( + MemAccOpTy op, PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + auto ptr = op.getPtr(); + auto ptrType = dyn_cast(ptr.getType()); + + if (!ptrType || op->hasAttr(ConverterUtils::discreteAttrName) + || op->hasAttr("Negative")) { + return failure(); + } + if (!offsetMap.contains(ptr)) + return op.emitError() << "PtrOffsetInfo should be computed\n" << ptr; + + auto ptrOffsetInfo = offsetMap.at(ptr); + bool flag = false; + if (ptrOffsetInfo.isNegativeFlag()) { + flag = true; + } + + if (ptrOffsetInfo.isStructured() && + (!ptrOffsetInfo.isScalarLike() || + llvm::all_of(ptrType.getShape(), [](int64_t dim) { return dim == 1; }))) { + if (flag) { + AddAssertForAddPtr(op, ptrOffsetInfo.getOffset(), rewriter); + return success(); + } else { + return failure(); + } + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Converting " << op->getName() << "\n"; + os << op << "\n"; + os << "isStructured = " << ptrOffsetInfo.isStructured() \ + << ",isScalarLike = " << ptrOffsetInfo.isScalarLike() \ + << ", isNegativeFlag = " << ptrOffsetInfo.isNegativeFlag() << "\n"; + }); + + if constexpr (std::is_same_v) + if (ptrOffsetInfo.isScalarLike()) { + if (flag) { + AddAssertForAddPtr(op, ptrOffsetInfo.getOffset(), rewriter); + } + 
splatAndLoadScenario(op, ptrOffsetInfo.getRank(), rewriter); + return success(); + } + + if constexpr (std::is_same_v) { + if (op->hasAttr(ConverterUtils::discreteMaskAttrName)) { + auto selectOp = op.getValue().template getDefiningOp(); + op = rewriter.replaceOpWithNewOp( + op, op.getPtr(), selectOp.getTrueValue(), selectOp.getCondition(), + op.getCache(), op.getEvict()); + rewriter.setInsertionPoint(op); + } + } + + auto srcPtr = ptrOffsetInfo.getPtr(); + auto offset = ptrOffsetInfo.getOffset(); + + // LoadLike is operation with result + bool isLoadLike = !op->use_empty(); + + Value zeroIdx = + rewriter.create(loc, rewriter.getIndexAttr(0)); + Value oneIdx = + rewriter.create(loc, rewriter.getIndexAttr(1)); + auto resultShape = ptrType.getShape(); + auto resultElementType = + cast(ptrType.getElementType()).getPointeeType(); + + Value iterArg = nullptr; + + // Only load case + if (isLoadLike) { + iterArg = + rewriter.create(loc, resultShape, resultElementType); + } + Value newOpResult = nullptr; + + auto insertPoint = rewriter.saveInsertionPoint(); + + SmallVector dims(resultShape.size(), rewriter.getIndexAttr(1)); + SmallVector offsets; + SmallVector strides; + SmallVector iterIdx; + + SmallVector localMemStrides(1, 1); + + for (auto size : llvm::reverse(resultShape)) { + localMemStrides.push_back(localMemStrides.back() * size); + } + localMemStrides.pop_back(); + + std::reverse(localMemStrides.begin(), localMemStrides.end()); + bool isExtractedAttrInserted = false; + for (const auto &[size, localMemStride] : + llvm::zip_equal(resultShape, localMemStrides)) { + // handle indirect dimension + strides.push_back(rewriter.getIndexAttr(localMemStride)); + Value sizeVal = + rewriter.create(loc, rewriter.getIndexAttr(size)); + scf::ForOp forOp; + if (isLoadLike) { + forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx, + ValueRange({iterArg})); + if (!newOpResult) { + newOpResult = forOp->getResult(0); + } else { + rewriter.create(loc, forOp->getResult(0)); + } + iterArg = forOp.getRegionIterArg(0); + } else { + forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx); + } + offsets.push_back(forOp.getInductionVar()); + iterIdx.push_back(forOp.getInductionVar()); + forOp->setAttr("ExtractedLoadOrStore", + UnitAttr::get(rewriter.getContext())); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + auto scalarLikeShape = SmallVector(dims.size(), 1); + auto scalarLikeType = + RankedTensorType::get(scalarLikeShape, resultElementType); + + auto extractedOffset = createExtractOp(loc, offset, iterIdx, rewriter); + if (flag) { + AddAssertForAddPtr(op, extractedOffset, rewriter); + } + if (isa(srcPtr.getType())) { + srcPtr = createExtractOp(loc, srcPtr, iterIdx, rewriter); + } + Value ptrToAccess = rewriter.create( + loc, srcPtr.getType(), srcPtr, extractedOffset); + + MemAccOpTy accessedValue = + createMemAccOp(op, ptrToAccess, loc, iterIdx, rewriter); + accessedValue->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + + if (isLoadLike) { + assert(iterArg && "Load case must have iterArg in for loop"); + + Value splatedValue = accessedValue->getResult(0); + if (!isa(splatedValue.getType())) { + splatedValue = + rewriter.create(loc, scalarLikeType, splatedValue); + } + auto result = rewriter.create( + loc, splatedValue, iterArg, offsets, dims, strides); + rewriter.create(loc, result->getResult(0)) + ->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + + rewriter.restoreInsertionPoint(insertPoint); + if constexpr (std::is_same_v) { 
+ if (op.getMask() && op.getOther()) { + rewriter + .replaceOpWithNewOp(op, op.getMask(), newOpResult, + op.getOther()) + ->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + ; + } else { + rewriter.replaceOp(op, newOpResult); + } + } else { + rewriter.replaceOp(op, newOpResult); + } + } else { + rewriter.eraseOp(op); + } + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After conversion\n" + << ptrToAccess.getDefiningOp() + ->template getParentOfType() + << "\n"; + }); + return success(); +} + +void exchangeValueWithOffset(Value value, Value ptr, const Location &loc, + RewriterBase &rewriter, + llvm::DenseMap &offsetMap) { + RewriterBase::InsertionGuard guard(rewriter); + if (auto blockArgument = dyn_cast(value)) { + rewriter.setInsertionPointToStart(blockArgument.getOwner()); + } else { + rewriter.setInsertionPointAfter(value.getDefiningOp()); + } + auto tempVar = rewriter + .create(loc, value.getType(), + ValueRange({})) + ->getResult(0); + auto valueType = cast(value.getType()); + auto offsetType = + RankedTensorType::get(valueType.getShape(), rewriter.getIntegerType(64)); + rewriter.replaceAllUsesWith(value, tempVar); + value.setType(offsetType); + auto splatedPtr = rewriter.create(loc, valueType, ptr); + auto newPtr = + rewriter.create(loc, valueType, splatedPtr, value); + parseAddPtr(newPtr, loc, rewriter, offsetMap); + rewriter.replaceAllUsesWith(tempVar, newPtr); +} + +void TritonToUnstructurePass::runPreparse(LoopLikeOpInterface op) { + IRRewriter rewriter(&getContext()); + auto loopResults = op.getLoopResults(); + if (!loopResults) + return; + for (OpResult res : *loopResults) { + if (auto tensorType = dyn_cast(res.getType())) { + auto loc = op.getLoc(); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Pre-parsing " << op->getName() << "\n" << op << "\n"; + }); + parse(res, loc, rewriter, offsetMapForLoopArgs); + + BlockArgument regionIterArg; + auto resOffsetInfo = offsetMapForLoopArgs.at(res); + if (!resOffsetInfo.isStructured() && + isa(tensorType.getElementType())) { + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Handling special case\n" << op << '\n'; + }); + // Get initArg + OpOperand *initArg = op.getTiedLoopInit(res); + PtrOffsetInfo initOffsetInfo = offsetMapForLoopArgs.at(initArg->get()); + // Get regionIterArg + regionIterArg = op.getTiedLoopRegionIterArg(res); + PtrOffsetInfo regionIterArgOffsetInfo = + offsetMapForLoopArgs.at(regionIterArg); + // Get yield + OpOperand *yieldedValue = op.getTiedLoopYieldedValue(regionIterArg); + PtrOffsetInfo yieldOffsetInfo = + offsetMapForLoopArgs.at(yieldedValue->get()); + // Exchange iter arg with offset + exchangeValueWithOffset(regionIterArg, initOffsetInfo.getPtr(), loc, + rewriter, offsetMapForLoopArgs); + rewriter.replaceAllUsesWith(regionIterArgOffsetInfo.getOffset(), + regionIterArg); + yieldedValue->set(yieldOffsetInfo.getOffset()); + initArg->set(initOffsetInfo.getOffset()); + exchangeValueWithOffset(res, initOffsetInfo.getPtr(), loc, rewriter, + offsetMapForLoopArgs); + } + + regionIterArg = op.getTiedLoopRegionIterArg(res); + offsetMap[regionIterArg] = PtrOffsetInfo(resOffsetInfo.getPtr()); + SmallVector ®ionIterArgOffset = + offsetMap[regionIterArg].getStructuredRef(); + SmallVector &resOffset = + resOffsetInfo.getStructuredRef(); + regionIterArgOffset.resize(resOffset.size()); + for (size_t i = 0; i < resOffset.size(); i++) + regionIterArgOffset[i] = resOffset[i]; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Pre-parsing result of\n" << regionIterArg << "\nis "; + for 
(size_t i = 0; i < regionIterArgOffset.size(); i++) + os << regionIterArgOffset[i]; + os << '\n'; + }); + } + } +} + +template +void TritonToUnstructurePass::runParse(MemAccOpTy op) { + IRRewriter rewriter(&getContext()); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Parsing " << op->getName() << "\n" << op << "\n"; + }); + parse(op.getPtr(), op.getLoc(), rewriter, offsetMap); +} + +void TritonToUnstructurePass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + moduleOp->walk([this](LoopLikeOpInterface op) { runPreparse(op); }); + moduleOp->walk([this](Operation *op) { + if (auto loadOp = dyn_cast(op)) { + runParse(loadOp); + } else if (auto storeOp = dyn_cast(op)) { + runParse(storeOp); + } else if (auto atomicRMWOp = dyn_cast(op)) { + runParse(atomicRMWOp); + } else if (auto atomicCASOp = dyn_cast(op)) { + runParse(atomicCASOp); + } + }); + + RewritePatternSet patterns(ctx); + + patterns.add, + UnstructuredMemAccessConverter, + UnstructuredMemAccessConverter, + UnstructuredMemAccessConverter>(ctx, + offsetMap); + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Parsing done\n"; + }); + + if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) { + moduleOp->emitError("failed to apply Patterns"); + signalPassFailure(); + } + + PassManager pm(&getContext(), moduleOp.getOperationName()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } +} + +void TritonToUnstructurePass::getDependentDialects( + DialectRegistry ®istry) const { + registry.insert(); +} + +std::unique_ptr> +triton::createTritonToUnstructurePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt b/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt new file mode 100644 index 000000000..7c3bf8311 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/Utils/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(MLIRTritonNPUUtils + Utils.cpp + InterleaveOptimization.cpp + + LINK_LIBS PUBLIC + MLIRIR + TritonIR +) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp b/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp new file mode 100644 index 000000000..3cb6d188d --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/Utils/InterleaveOptimization.cpp @@ -0,0 +1,663 @@ +//===- InterleaveOptimization.cpp -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/InterleaveOptimization.h"
+#include "Utils/Utils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "mlir/IR/Operation.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include
+#include
+
+namespace mlir {
+namespace triton {
+// For the origin MemRefType of a ReinterpretCastOp under the interleave state,
+// adjust its shape info by doubling the last dimension.
+MemRefType expandInterleaveMemRefType(MemRefType originType) {
+  // Double the last dimension shape
+  SmallVector<int64_t> shape(originType.getShape());
+  shape.back() = shape.back() * 2;
+
+  // Adjust layout attribute
+  StridedLayoutAttr originLayout =
+      llvm::dyn_cast<StridedLayoutAttr>(originType.getLayout());
+  // If offset is static, just reset it to 0
+  auto offset = originLayout.getOffset() == ShapedType::kDynamic
+                    ? originLayout.getOffset()
+                    : 0;
+  // Set last dimension stride to 1
+  SmallVector<int64_t> stride(originLayout.getStrides());
+  stride.back() = 1;
+
+  return MemRefType::get(
+      shape, originType.getElementType(),
+      StridedLayoutAttr::get(originType.getContext(), offset, stride));
+}
+
+// *********************
+// ** NOTE **
+// *********************
+// Determining the new offset is a little tricky and specific.
+// Consider this pattern in the Triton language:
+//
+// dim_range = tl.arange(0, BLOCK // 2)
+// last_dim_even_range = dim_range * 2
+// last_dim_odd_range = dim_range * 2 + 1
+//
+// Here `multiply two` means that the last dimension stride is 2, and
+// `add constant one` marks the odd-index part of the deinterleave result.
+//
+// Therefore, even-index and odd-index interleave/deinterleave are
+// distinguished by whether the last dimension range explicitly applies
+// `add constant one` and nothing else. In the IR this shows up as whether the
+// defining op of `castOffset` is an arith::AddIOp, since that op carries the
+// `add constant one` operation after LegacyAddPtrConverter.
+//
+// The index mode still has to be passed on to the interleave/deinterleave; in
+// other words, `add constant one` should act on the offset of the next
+// insert_slice/extract_slice.
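// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] The NOTE above
// classifies a cast offset as even-index or odd-index deinterleave depending
// on whether an `add constant one` was applied to the last-dimension range.
// The standalone snippet below replays that decision only for the
// constant-offset branch of recountReinterpretCastOffset (offset 0 -> EVEN,
// offset 1 -> ODD with the +1 stripped); classifyConstantOffset is an invented
// name, and this is not the MLIR implementation.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>

enum class IndexMode { EVEN_MODE, ODD_MODE };

// A constant cast offset may only be 0 or 1; 1 means "odd part", and the
// offset is reset to 0 so the new reinterpret_cast describes the whole tensor.
std::pair<int64_t, IndexMode> classifyConstantOffset(int64_t offset) {
  assert(offset == 0 || offset == 1);
  if (offset == 1)
    return {0, IndexMode::ODD_MODE};
  return {offset, IndexMode::EVEN_MODE};
}

int main() {
  auto [evenOff, evenMode] = classifyConstantOffset(0);
  auto [oddOff, oddMode] = classifyConstantOffset(1);
  std::cout << "offset 0 -> "
            << (evenMode == IndexMode::EVEN_MODE ? "even" : "odd")
            << ", new offset " << evenOff << "\n";
  std::cout << "offset 1 -> "
            << (oddMode == IndexMode::EVEN_MODE ? "even" : "odd")
            << ", new offset " << oddOff << "\n";
  return 0;
}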
+// The new reinterpretcast just wanna describe whole tensor, so new castOffset +// is just from non-last diemsnion accumulation and remove `add constant one` +std::pair +recountReinterpretCastOffset(OpFoldResult originOffset, Builder &builder) { + // To trace value type offset + std::function traceOffset = [&](Operation *op) -> bool { + // Consider constant one in `add constant one` operation + if (llvm::isa(op)) + return false; + + if (llvm::isa(op)) { + auto addOp = llvm::cast(op); + if (auto constLHS = addOp.getLhs().getDefiningOp()) { + assert(dyn_cast(constLHS.getValueAttr()).getInt() == 1 && + "Arith::constant value of addi's operand must be 1 when " + "calculate deinterleave offset"); + return false; + } + if (auto constRHS = addOp.getRhs().getDefiningOp()) { + assert(dyn_cast(constRHS.getValueAttr()).getInt() == 1 && + "Arith::constant value of addi's operand must be 1 when " + "calculate deinterleave offset"); + return false; + } + } + return true; + }; + + IndexMode evenOrOdd = IndexMode::EVEN_MODE; + // Reuse origin offset if there's no 'add constant one' + OpFoldResult newOffset = originOffset; + if (llvm::isa(originOffset)) { + // If offset is constant int(IndexAttr), + // the int value could only be 0 or 1 + int64_t intOffset = getConstantIntValue(originOffset).value(); + assert((intOffset == 0 || intOffset == 1)); + if (intOffset == 1) { + evenOrOdd = IndexMode::ODD_MODE; + newOffset = builder.getIndexAttr(0); + } + } else if (llvm::isa(originOffset)) { + if (!traceOffset(originOffset.get().getDefiningOp())) { + evenOrOdd = IndexMode::ODD_MODE; + Operation *traceResult = findFirstMatchingOperandDef( + originOffset.get().getDefiningOp(), traceOffset); + assert(traceResult->getNumResults() == 1 && + "Offset defining operation must have one result"); + newOffset = traceResult->getResult(0); + } + } + + return {newOffset, evenOrOdd}; +} + +LogicalResult +DeinterleaveStatusOptimization(triton::LoadOp op, + triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) { + auto ptr = adaptor.getPtr(); + if (auto reinterpretCast = ptr.getDefiningOp()) { + auto loc = op.getLoc(); + + // 1. Get new source memref type + auto srcType = expandInterleaveMemRefType(reinterpretCast.getType()); + + // 2. Create new ReinterpretCastOp + auto originCastOffset = reinterpretCast.getConstifiedMixedOffset(); + auto castSize = reinterpretCast.getConstifiedMixedSizes(); + auto castStride = reinterpretCast.getConstifiedMixedStrides(); + // Actually, `castSize` is always constant value as `MemRefType` result + if (auto lastDimSize = getConstantIntValue(castSize.back())) { + castSize.back() = rewriter.getIndexAttr(lastDimSize.value() * 2); + } else { + return failure(); + } + // Last element of castStride is also constant value as prerequisite + // is that last dimension stride of casted memref type is always 2. + castStride.back() = rewriter.getIndexAttr(1); + auto [castOffset, indexMode] = + recountReinterpretCastOffset(originCastOffset, rewriter); + auto newCastOp = rewriter.create( + loc, srcType, reinterpretCast.getViewSource(), castOffset, castSize, + castStride); + + // 3. Create new memref allocOp + auto newAllocOp = rewriter.create( + loc, MemRefType::get(srcType.getShape(), srcType.getElementType())); + + // 4. 
Implement memref copy and bufferization back to tensor + rewriter.create(loc, newCastOp.getResult(), newAllocOp); + Value newTensor = rewriter.create( + loc, + RankedTensorType::get(srcType.getShape(), srcType.getElementType()), + newAllocOp, true /* restrict */, true /* writable */); + + // 5. Implement tensor extract_slice to represent deinterleave + // Here use `castOffset` to determine whether even index deinterleave or + // odd index. + SmallVector extractOffsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector extractStrides(srcType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector extractSizes = llvm::to_vector( + llvm::map_range(srcType.getShape(), [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + // Adjust extract_slice shape + switch (indexMode) { + case IndexMode::EVEN_MODE: + extractOffsets.back() = rewriter.getIndexAttr(0); + break; + case IndexMode::ODD_MODE: + extractOffsets.back() = rewriter.getIndexAttr(1); + break; + } + extractStrides.back() = rewriter.getIndexAttr(2); + extractSizes.back() = rewriter.getIndexAttr(srcType.getShape().back() / 2); + + Value deinterleaveSlice = rewriter.create( + loc, newTensor, extractOffsets, extractSizes, extractStrides); + + rewriter.replaceOp(op, deinterleaveSlice); + return success(); + } + + return failure(); +} + +LogicalResult DeinterleaveStatusWithMaskOptimization( + triton::LoadOp op, triton::LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter, MaskState &mstate, + memref::AllocOp originAllocOp) { + auto ptr = adaptor.getPtr(); + if (auto reinterpretCast = ptr.getDefiningOp()) { + auto loc = op.getLoc(); + + // 1. Get new source memref type + auto srcType = expandInterleaveMemRefType(reinterpretCast.getType()); + + // 2. Create new ReinterpretCastOp + auto originCastOffset = reinterpretCast.getConstifiedMixedOffset(); + auto castSize = reinterpretCast.getConstifiedMixedSizes(); + auto castStride = reinterpretCast.getConstifiedMixedStrides(); + + if (auto lastDimSize = getConstantIntValue(castSize.back())) { + castSize.back() = rewriter.getIndexAttr(lastDimSize.value() * 2); + } else { + return failure(); + } + castStride.back() = rewriter.getIndexAttr(1); + auto [castOffset, indexMode] = + recountReinterpretCastOffset(originCastOffset, rewriter); + + auto newCastOp = rewriter.create( + loc, srcType, reinterpretCast.getViewSource(), castOffset, castSize, + castStride); + + // 3. Create new memref allocOp + // To reuse existing linalg::fill, here need to change insertion point + auto savedInsertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(originAllocOp); + auto newAllocOp = rewriter.create( + loc, MemRefType::get(srcType.getShape(), srcType.getElementType())); + rewriter.restoreInsertionPoint(savedInsertPoint); + + // 4. Broadcast other value by linalg.fill if necessary + auto other = op.getOther(); + // While deinterleave optimization will just adjust last dimension info + // and origin mask state wouldn't involve last dimension. 
Therefore in + // current `scf.if + linalg.fill` combination, condition of `if` could be + // kept and just replace linalg.fill' + if (other) { + assert(originAllocOp->hasOneUse() && + llvm::isa(*(originAllocOp->getUsers().begin()))); + auto originFillOp = + llvm::dyn_cast(*(originAllocOp->getUsers().begin())); + + assert(llvm::isa(originFillOp->getParentOp())); + auto ifOp = llvm::dyn_cast(originFillOp->getParentOp()); + + auto newFillOp = ifOp.getThenBodyBuilder().create( + originFillOp.getLoc(), originFillOp.getInputs(), + ValueRange{newAllocOp}); + rewriter.replaceOp(originFillOp, newFillOp); + } + + // 5. Implement new subview, memref copy and bufferization back to tensor + SmallVector subviewStrides(srcType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector subviewOffsets = mstate.offsets; + SmallVector subviewSizes = mstate.dims; + // Just adjust last dimension size to double + std::optional originSubviewLastDim = + getConstantIntValue(subviewSizes.back()); + assert(originSubviewLastDim.has_value()); + subviewSizes.back() = + rewriter.getIndexAttr(originSubviewLastDim.value() * 2); + + auto argSubviewType = memref::SubViewOp::inferResultType( + srcType, subviewOffsets, subviewSizes, subviewStrides); + // alloca subview type doesn't carry layout attribute + auto allocSubviewType = memref::SubViewOp::inferResultType( + newAllocOp.getType(), subviewOffsets, subviewSizes, subviewStrides); + + memref::SubViewOp srcSubview = rewriter.create( + loc, llvm::cast(argSubviewType), newCastOp, subviewOffsets, + subviewSizes, subviewStrides); + memref::SubViewOp dstSubview = rewriter.create( + loc, llvm::cast(allocSubviewType), newAllocOp, + subviewOffsets, subviewSizes, subviewStrides); + rewriter.create(loc, srcSubview, dstSubview); + Value newTensor = rewriter.create( + loc, + RankedTensorType::get(srcType.getShape(), srcType.getElementType()), + newAllocOp, true /* restrict */, true /* writable */); + + // 6. Implement tensor extract_slice to represent deinterleave + // Here use `castOffset` to determine whether even index deinterleave or + // odd index. + SmallVector extractOffsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector extractStrides(srcType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector extractSizes = llvm::to_vector( + llvm::map_range(srcType.getShape(), [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + switch (indexMode) { + case IndexMode::EVEN_MODE: + extractOffsets.back() = rewriter.getIndexAttr(0); + break; + case IndexMode::ODD_MODE: + extractOffsets.back() = rewriter.getIndexAttr(1); + break; + } + extractStrides.back() = rewriter.getIndexAttr(2); + extractSizes.back() = rewriter.getIndexAttr(srcType.getShape().back() / 2); + + Value deinterleaveSlice = rewriter.create( + loc, newTensor, extractOffsets, extractSizes, extractStrides); + + rewriter.replaceOp(op, deinterleaveSlice); + return success(); + } + return failure(); +} + +LogicalResult +InterleaveStatusOptimization(SmallVector materializeVec) { + OpBuilder builder(materializeVec[1]); + auto loc = materializeVec[1]->getLoc(); + + auto firstReinterpretCastOp = + llvm::dyn_cast( + materializeVec[0]) + .getDest() + .getDefiningOp(); + auto secondReinterpretCastOp = + llvm::dyn_cast( + materializeVec[1]) + .getDest() + .getDefiningOp(); + + assert(firstReinterpretCastOp && secondReinterpretCastOp); + + // Judge whether two `ReinterpretCastOp` shape satisfy interleave state + // a. 
both size are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedSizes(), + secondReinterpretCastOp.getConstifiedMixedSizes())) { + return failure(); + } + // b. both strides are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedStrides(), + secondReinterpretCastOp.getConstifiedMixedStrides())) { + return failure(); + } + // c. both offsets should satisfy tricky rule + auto firstOriginCastOffset = + firstReinterpretCastOp.getConstifiedMixedOffset(); + auto secondOriginCastOffset = + secondReinterpretCastOp.getConstifiedMixedOffset(); + std::pair indexModeRecord; + OpFoldResult newCastOffset; + if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ static_cast(secondIndexMode))) + return failure(); + newCastOffset = builder.getIndexAttr(0); + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ + static_cast(secondIndexMode)) || + (llvm::dyn_cast(firstCastOffset) != + llvm::dyn_cast(secondCastOffset))) + return failure(); + + if (firstIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(firstCastOffset); + } + if (secondIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(secondCastOffset); + } + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else { + return failure(); + } + + // Create new op + // 1. Get new destination memref type + auto dstType = expandInterleaveMemRefType(firstReinterpretCastOp.getType()); + + // 2. New tensor::EmptyOp + auto emptyTensor = builder.create(loc, dstType.getShape(), + dstType.getElementType()); + + // 3. New insert_slice from materialization source into new empty tensor + SmallVector insertOffsets(dstType.getRank(), + builder.getIndexAttr(0)); + SmallVector insertStrides(dstType.getRank(), + builder.getIndexAttr(1)); + SmallVector insertSizes = llvm::to_vector( + llvm::map_range(dstType.getShape(), [&](int64_t dim) -> OpFoldResult { + return builder.getIndexAttr(dim); + })); + insertStrides.back() = builder.getIndexAttr(2); + insertSizes.back() = builder.getIndexAttr(dstType.getShape().back() / 2); + if (indexModeRecord.first == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertFirst = builder.create( + loc, + llvm::dyn_cast( + materializeVec[0]) + .getSource(), + emptyTensor.getResult(), insertOffsets, insertSizes, insertStrides); + + if (indexModeRecord.second == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertSecond = builder.create( + loc, + llvm::dyn_cast( + materializeVec[1]) + .getSource(), + insertFirst.getResult(), insertOffsets, insertSizes, insertStrides); + + // 4. 
Reinterpret_cast block arg + auto newCastSize = firstReinterpretCastOp.getConstifiedMixedSizes(); + auto newCastStride = firstReinterpretCastOp.getConstifiedMixedStrides(); + newCastSize.back() = builder.getIndexAttr(dstType.getShape().back()); + newCastStride.back() = builder.getIndexAttr(1); + auto newCastOp = builder.create( + loc, dstType, firstReinterpretCastOp.getViewSource(), newCastOffset, + newCastSize, newCastStride); + + // 5. Create new bufferization::MaterializeInDestinationOp + auto newStoreOp = builder.create( + loc, insertSecond.getResult(), newCastOp.getResult()); + // Setting writable is necessary as dst is memref type + newStoreOp.setWritable(true); + + // 6. Erase origin materialization + materializeVec[0]->erase(); + materializeVec[1]->erase(); + + return success(); +} + +LogicalResult +InterleaveStatusWithMaskOptimization(SmallVector materializeVec) { + OpBuilder builder(materializeVec[1]); + + auto firstSubviewOpOfReCast = + llvm::dyn_cast( + materializeVec[0]) + .getDest() + .getDefiningOp(); + auto firstSrcExtractSlice = + llvm::dyn_cast( + materializeVec[0]) + .getSource() + .getDefiningOp(); + auto firstReinterpretCastOp = firstSubviewOpOfReCast.getSource() + .getDefiningOp(); + + auto secondSubviewOpOfReCast = + llvm::dyn_cast( + materializeVec[1]) + .getDest() + .getDefiningOp(); + auto secondSrcExtractSlice = + llvm::dyn_cast( + materializeVec[1]) + .getSource() + .getDefiningOp(); + auto secondReinterpretCastOp = + secondSubviewOpOfReCast.getSource() + .getDefiningOp(); + + // 1. Both source shapes of subview and extract_slice are equal + if (firstSubviewOpOfReCast.getSourceType().getShape() != + firstSrcExtractSlice.getSourceType().getShape()) + return failure(); + if (secondSubviewOpOfReCast.getSourceType().getShape() != + secondSrcExtractSlice.getSourceType().getShape()) + return failure(); + if (firstSubviewOpOfReCast.getSourceType().getShape() != + secondSubviewOpOfReCast.getSourceType().getShape()) + return failure(); + + // 2. both mask state are equal + std::function cmpFunc = + mlir::isEqualConstantIntOrValue; + if (!mlir::detail::sameOffsetsSizesAndStrides(firstSubviewOpOfReCast, + firstSrcExtractSlice, cmpFunc)) + return failure(); + if (!mlir::detail::sameOffsetsSizesAndStrides(secondSubviewOpOfReCast, + secondSrcExtractSlice, cmpFunc)) + return failure(); + if (!mlir::detail::sameOffsetsSizesAndStrides( + firstSubviewOpOfReCast, secondSubviewOpOfReCast, cmpFunc)) + return failure(); + + // 3. Still judge whether two `ReinterpretCastOp` shape satisfy request + // a. both size are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedSizes(), + secondReinterpretCastOp.getConstifiedMixedSizes())) + return failure(); + // b. both strides are equal + if (!isEqualConstantIntOrValueArray( + firstReinterpretCastOp.getConstifiedMixedStrides(), + secondReinterpretCastOp.getConstifiedMixedStrides())) + return failure(); + // c. 
both offsets should satisfy tricky rule + auto firstOriginCastOffset = + firstReinterpretCastOp.getConstifiedMixedOffset(); + auto secondOriginCastOffset = + secondReinterpretCastOp.getConstifiedMixedOffset(); + std::pair indexModeRecord; + OpFoldResult newCastOffset; + if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ static_cast(secondIndexMode))) + return failure(); + newCastOffset = builder.getIndexAttr(0); + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else if (llvm::isa(firstOriginCastOffset) && + llvm::isa(secondOriginCastOffset)) { + auto [firstCastOffset, firstIndexMode] = + recountReinterpretCastOffset(firstOriginCastOffset, builder); + auto [secondCastOffset, secondIndexMode] = + recountReinterpretCastOffset(secondOriginCastOffset, builder); + + if (!(static_cast(firstIndexMode) ^ + static_cast(secondIndexMode)) || + (llvm::dyn_cast(firstCastOffset) != + llvm::dyn_cast(secondCastOffset))) + return failure(); + + if (firstIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(firstCastOffset); + } + if (secondIndexMode == IndexMode::EVEN_MODE) { + newCastOffset = llvm::dyn_cast(secondCastOffset); + } + indexModeRecord = {firstIndexMode, secondIndexMode}; + + } else { + return failure(); + } + auto loc = materializeVec[1]->getLoc(); + + // Create new op + // 1. Get new destination memref type + auto dstType = expandInterleaveMemRefType(firstReinterpretCastOp.getType()); + + // 2. New tensor::EmptyOp + auto emptyTensor = builder.create(loc, dstType.getShape(), + dstType.getElementType()); + + // 3. New insert_slice from extract_slice source into new empty tensor + SmallVector insertOffsets(dstType.getRank(), + builder.getIndexAttr(0)); + SmallVector insertStrides(dstType.getRank(), + builder.getIndexAttr(1)); + SmallVector insertSizes = llvm::to_vector( + llvm::map_range(dstType.getShape(), [&](int64_t dim) -> OpFoldResult { + return builder.getIndexAttr(dim); + })); + insertStrides.back() = builder.getIndexAttr(2); + insertSizes.back() = builder.getIndexAttr(dstType.getShape().back() / 2); + if (indexModeRecord.first == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertFirst = builder.create( + loc, firstSrcExtractSlice.getSource(), emptyTensor.getResult(), + insertOffsets, insertSizes, insertStrides); + + if (indexModeRecord.second == IndexMode::ODD_MODE) { + insertOffsets.back() = builder.getIndexAttr(1); + } else { + insertOffsets.back() = builder.getIndexAttr(0); + } + auto insertSecond = builder.create( + loc, secondSrcExtractSlice.getSource(), insertFirst.getResult(), + insertOffsets, insertSizes, insertStrides); + + // 4. To enable store with mask, create new extract_slice + SmallVector extractOffsets = + firstSrcExtractSlice.getMixedOffsets(); + SmallVector extractStrides = + firstSrcExtractSlice.getMixedStrides(); + SmallVector extractSizes = firstSrcExtractSlice.getMixedSizes(); + assert(llvm::isa(extractSizes.back())); + extractSizes.back() = builder.getIndexAttr( + getConstantIntValue(extractSizes.back()).value() * 2); + auto newSrcExtractSlice = builder.create( + loc, insertSecond.getResult(), extractOffsets, extractSizes, + extractStrides); + + // 5. 
Reinterpret_cast block arg + auto newCastSize = firstReinterpretCastOp.getConstifiedMixedSizes(); + auto newCastStride = firstReinterpretCastOp.getConstifiedMixedStrides(); + newCastSize.back() = builder.getIndexAttr(dstType.getShape().back()); + newCastStride.back() = builder.getIndexAttr(1); + auto newCastOp = builder.create( + loc, dstType, firstReinterpretCastOp.getViewSource(), newCastOffset, + newCastSize, newCastStride); + + // 6. Create new memref::SubViewOp of above new reinterpret_cast + // Here could reuse shape info of new extract_slice + auto dstSubviewType = memref::SubViewOp::inferResultType( + dstType, extractOffsets, extractSizes, extractStrides); + auto newSubviewOpOfReCast = builder.create( + loc, llvm::cast(dstSubviewType), newCastOp, extractOffsets, + extractSizes, extractStrides); + + // 7. Create new bufferization::MaterializeInDestinationOp + auto newStoreOp = builder.create( + loc, newSrcExtractSlice.getResult(), newSubviewOpOfReCast.getResult()); + // Setting writable is necessary as dst is memref type + newStoreOp.setWritable(true); + + // 8. Erase origin operation + materializeVec[0]->erase(); + materializeVec[1]->erase(); + firstSubviewOpOfReCast->erase(); + firstSrcExtractSlice->erase(); + secondSubviewOpOfReCast->erase(); + secondSrcExtractSlice->erase(); + + return success(); +} + +} // namespace triton +} // namespace mlir \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp b/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp new file mode 100644 index 000000000..af1e617a7 --- /dev/null +++ b/third_party/ascend/triton-adapter/lib/Utils/Utils.cpp @@ -0,0 +1,867 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
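// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] The
// InterleaveStatus(WithMask)Optimization routines above fuse two half-width
// stores (even lanes and odd lanes of the last dimension) into one store of a
// doubled buffer by inserting each source with stride 2 at offset 0 or 1. The
// plain C++ below reproduces only that interleaving index arithmetic on 1-D
// vectors; interleaveLastDim and the sample data are invented for the
// illustration.
// ---------------------------------------------------------------------------
#include <cassert>
#include <iostream>
#include <vector>

// result[2*i] = even[i], result[2*i + 1] = odd[i] -- the effect of the two
// stride-2 insert_slices into the doubled destination.
std::vector<int> interleaveLastDim(const std::vector<int> &even,
                                   const std::vector<int> &odd) {
  assert(even.size() == odd.size());
  std::vector<int> result(even.size() * 2);
  for (size_t i = 0; i < even.size(); ++i) {
    result[2 * i] = even[i];    // insert_slice at offset 0, stride 2
    result[2 * i + 1] = odd[i]; // insert_slice at offset 1, stride 2
  }
  return result;
}

int main() {
  std::vector<int> even = {0, 2, 4, 6};
  std::vector<int> odd = {1, 3, 5, 7};
  for (int v : interleaveLastDim(even, odd))
    std::cout << v << " "; // prints: 0 1 2 3 4 5 6 7
  std::cout << "\n";
  return 0;
}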
+// +//===----------------------------------------------------------------------===// + +#include "Utils/Utils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "TritonNPU-Utils" + +namespace mlir { + +static Value createConstIndexValueOp(const Location &loc, OpBuilder &b, + int64_t value) { + return b.create(loc, b.getIndexAttr(value)).getResult(); +} + +static std::optional getConstantOfAttr(const OpFoldResult &arg) { + if (isa(arg)) { + return getConstantIntValue(arg); + } + + return std::nullopt; +} + +namespace ConverterUtils { + +std::optional getLastStrideOfReinterpretCastOp(memref::ReinterpretCastOp op) { + SmallVector mixedStrides = op.getMixedStrides(); + if (mixedStrides.empty()) { + op->emitError("ReinterpretCastOp has no strides"); + return std::nullopt; + } + + OpFoldResult lastStride = mixedStrides.back(); + if (auto attr = lastStride.dyn_cast()) { + return getConstantOfAttr(lastStride); + } else if (auto value = lastStride.dyn_cast()) { + auto defOp = value.getDefiningOp(); + if (auto constIndexOp = dyn_cast(defOp)) { + int64_t constValue = constIndexOp.value(); + return constValue; + } + else if (auto constIntOp = dyn_cast(defOp)) { + int64_t constValue = constIntOp.value(); + return constValue; + } + } + return std::nullopt; +} + +bool isaPermutedMemRefType(MemRefType memRefType) { + auto [ptrStrides, ptrOffsets] = getStridesAndOffset(memRefType); + LLVM_DEBUG({ + llvm::dbgs()<<"---------- [BEG] ptrStrides ----------\n"; + for(auto stride: ptrStrides)llvm::dbgs()< order) { + auto sourceType = cast(source.getType()); + auto sourceRank = sourceType.getRank(); + + SmallVector perm(order); + SmallVector originalShape(sourceType.getShape()); + SmallVector transposedShape(sourceRank); + for (size_t i = 0; i < sourceRank; i++) { + transposedShape[i] = originalShape[perm[i]]; + } + + Value transposeInit = rewriter.create( + loc, transposedShape, sourceType.getElementType()); + + Value transpose = + rewriter.create(loc, source, transposeInit, perm) + .getResults()[0]; + + return transpose; +} + +SmallVector getNParallelLoopsAttrs(unsigned n) { + return SmallVector(n, utils::IteratorType::parallel); +} + +Value getScalarValue(Value operand, Location loc, + ConversionPatternRewriter &rewriter) { + SmallVector ops; + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = mlir::TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + 
}) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + rewriter, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +memref::SubViewOp makeSubViewOp(Value src, + const llvm::SmallVector &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter) { + auto srcType = cast(src.getType()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, dyn_cast(dstType), + src, offsets, sizes, strides); +} + +tensor::ExtractSliceOp makeExtractSliceOp(Value src, + const llvm::SmallVector &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter) { + auto srcType = cast(src.getType()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); + auto dstType = + tensor::ExtractSliceOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, dstType, src, offsets, + sizes, strides); +} + +std::optional getFullShapeOp(Value val, + ConversionPatternRewriter &rewriter) { + assert(isa(val.getType())); + + if (isa(val)) { + auto blockArg = dyn_cast(val); + auto blockOp = blockArg.getOwner()->getParentOp(); + if (isa(blockOp)) { + auto forOp = dyn_cast(blockOp); + auto operand = forOp.getTiedLoopInit(blockArg)->get(); + return getFullShapeOp(operand, rewriter); + } else { + emitError(val.getLoc()) + << "getFullShapeOp() only support ReinterpretCastOp " + "and scf.for's block argument, but got : " + << val << "\n"; + } + return std::nullopt; + } + + if (!isa(val.getDefiningOp())) { + emitError(val.getLoc()) + << "getFullShapeOp() only support ReinterpretCastOp " + "and scf.for's block argument, but got : " + << val << "\n"; + return std::nullopt; + } + + auto reCastOp = val.getDefiningOp(); + if (reCastOp->hasAttr("tensor_ptr_full_shape")) + return reCastOp; + + return getFullShapeOp(reCastOp.getSource(), rewriter); +} + +SmallVector +getBoundarySizes(llvm::ArrayRef boundaryCheck, Value 
ptr, + const Location &loc, ConversionPatternRewriter &rewriter) { + if (isa(ptr.getType())) + ptr = rewriter.getRemappedValue(ptr); + + auto shapedType = dyn_cast_if_present(ptr.getType()); + assert(shapedType && shapedType.hasStaticShape()); + + auto fullShapeOp = getFullShapeOp(ptr, rewriter); + + assert(fullShapeOp.has_value()); + SmallVector boundarySize = + getAsIndexOpFoldResult(rewriter.getContext(), shapedType.getShape()); + + auto fullShapeReCast = + dyn_cast(fullShapeOp.value()); + OpFoldResult curPtrOffset; + if (auto curReCast = ptr.getDefiningOp()) { + curPtrOffset = curReCast.getConstifiedMixedOffset(); + } else if (isa(ptr) && + isa(ptr.getParentBlock()->getParentOp())) { + // Here's to process loop state where ptr is just from loop interator. + // Following assertion corresponds to conversion result from `rewriteFor` + auto blockArg = dyn_cast(ptr); + auto forOp = dyn_cast(ptr.getParentBlock()->getParentOp()); + auto initReCastOfLoop = forOp.getTiedLoopInit(blockArg) + ->get() + .getDefiningOp(); + assert(initReCastOfLoop && initReCastOfLoop.getOffsets().size() == 1); + Value initReCastOffset = initReCastOfLoop.getOffsets()[0]; + + for (OpOperand &use : initReCastOffset.getUses()) { + if (use.getOwner() == initReCastOfLoop) + continue; + else if (isa(use.getOwner())) + continue; + else if (use.getOwner() == forOp) + curPtrOffset = OpFoldResult(forOp.getTiedLoopRegionIterArg(&use)); + else + llvm_unreachable("Illegal interation offset after rewriteFor"); + } + } else { + llvm_unreachable("Unsupported state when check tensor_ptr boundary"); + } + + assert(curPtrOffset); + + OpFoldResult offsetShift = subOpFoldResult( + curPtrOffset, fullShapeReCast.getConstifiedMixedOffset(), loc, rewriter); + + for (int i = 0; i < shapedType.getRank(); ++i) { + if (llvm::find(boundaryCheck, i) != boundaryCheck.end()) { + auto fullShape = fullShapeReCast.getConstifiedMixedSizes()[i]; + + OpFoldResult curOffset = divOpFoldResult( + offsetShift, fullShapeReCast.getConstifiedMixedStrides()[i], loc, + rewriter); + OpFoldResult curLeftSize = + maxOpFoldResult(subOpFoldResult(fullShape, curOffset, loc, rewriter), + rewriter.getIndexAttr(0), loc, rewriter); + + boundarySize[i] = + minOpFoldResult(boundarySize[i], curLeftSize, loc, rewriter); + + offsetShift = remOpFoldResult( + offsetShift, fullShapeReCast.getConstifiedMixedStrides()[i], loc, + rewriter); + } + } + + return boundarySize; +} + +SmallVector getBroadcastDims(RankedTensorType src, + RankedTensorType dst) { + SmallVector broadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (dstShape[i] != srcShape[i]) { + assert(srcShape[i] == 1 && + "Size of source broadcast dimension must be 1"); + broadcastDims.push_back(i); + } + } + assert(!broadcastDims.empty() && "Cannot identify broadcast dimension"); + return broadcastDims; +} + +// Dimensions of collapesd tensor is all unbroadcast dims +SmallVector getUnbroadcastDims(RankedTensorType src, + RankedTensorType dst) { + SmallVector unbroadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (dstShape[i] == srcShape[i]) { + unbroadcastDims.emplace_back(srcShape[i]); + } + } + return unbroadcastDims; +} + +} // namespace ConverterUtils + +namespace triton { + +mlir::Operation * +findFirstMatchingOperandDef(mlir::Operation *rootOp, + const std::function &condFn) { + LLVM_DEBUG(llvm::dbgs() << "[findFirstMatchingOperandDef] Current op: " 
+ << *rootOp << "\n"); + mlir::Value lhs = nullptr; + mlir::Value rhs = nullptr; + if (auto op = dyn_cast(rootOp)) { + lhs = op.getPtr(); + rhs = op.getOffset(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getSrc(); + } else if (auto op = dyn_cast(rootOp)) { + } else { + rootOp->emitRemark("Backtracing encounters unsupported Operation"); + return nullptr; + } + // Backtrace operands + if (!lhs) { + return nullptr; + } + auto lhsDef = lhs.getDefiningOp(); + mlir::Operation *targetOp; + if (lhsDef) { + if (condFn(lhsDef)) { + targetOp = lhsDef; + } else { + targetOp = findFirstMatchingOperandDef(lhsDef, condFn); + } + if (targetOp) { + return targetOp; + } + } + if (!rhs) { + return nullptr; + } + auto rhsDef = rhs.getDefiningOp(); + if (rhsDef) { + if (condFn(rhsDef)) { + targetOp = rhsDef; + } else { + targetOp = findFirstMatchingOperandDef(rhsDef, condFn); + } + if (targetOp) { + return targetOp; + } + } + return nullptr; +} + +void traverseBackwardUpdateOperandChainIf( + Operation *op, std::function conditionFn, + std::function stopFn, + std::function actionFn, OpBuilder &builder, + DenseSet &handledOperation) { + + if (!op || handledOperation.contains(op)) + return; + + handledOperation.insert(op); + + if (stopFn(op)) + return; + + if (conditionFn(op)) + actionFn(builder, op); + + DenseSet handledOperand; + + std::function handler = [&](Value operand) { + if (handledOperand.contains(operand)) + return; + handledOperand.insert(operand); + if (Operation *defOp = operand.getDefiningOp()) { + traverseBackwardUpdateOperandChainIf(defOp, conditionFn, stopFn, actionFn, + builder, handledOperation); + } else { + auto blockArgument = cast(operand); + if (auto loopOp = + dyn_cast(blockArgument.getOwner()->getParentOp())) { + OpOperand *initArgOperand = loopOp.getTiedLoopInit(blockArgument); + if (!initArgOperand) + return; + Value initArg = initArgOperand->get(); + handler(initArg); + Value yieldedValue = loopOp.getTiedLoopYieldedValue(blockArgument)->get(); + if (yieldedValue != blockArgument) + handler(yieldedValue); + } + } + }; + + for (Value operand : op->getOperands()) { + handler(operand); + } + + if (auto loopOp = dyn_cast(op)) { + for (auto yieldedValue: loopOp.getYieldedValues()) + handler(yieldedValue); + } +} + +// Note: rootOp will also be processed. 
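// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] The
// traverseBackwardUpdateOperandChainIf helpers around this point walk the
// operand chain of a root operation: stop when a stop predicate fires, apply
// an action to every node matching a condition, and use a visited set to avoid
// re-processing. The toy C++ below shows the same traversal shape on a
// hand-rolled DAG; Node and traverseBackwardIf are invented names and do not
// correspond to MLIR types.
// ---------------------------------------------------------------------------
#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node *> operands; // defining nodes of this node's operands
};

void traverseBackwardIf(Node *node,
                        const std::function<bool(Node *)> &conditionFn,
                        const std::function<bool(Node *)> &stopFn,
                        const std::function<void(Node *)> &actionFn,
                        std::unordered_set<Node *> &visited) {
  if (!node || visited.count(node))
    return;                 // already handled
  visited.insert(node);
  if (stopFn(node))
    return;                 // stop predicate: do not act, do not recurse
  if (conditionFn(node))
    actionFn(node);         // act on matching nodes (root included)
  for (Node *operand : node->operands)
    traverseBackwardIf(operand, conditionFn, stopFn, actionFn, visited);
}

int main() {
  Node a{"a", {}}, b{"b", {&a}}, c{"c", {&a}}, root{"root", {&b, &c}};
  std::unordered_set<Node *> visited;
  traverseBackwardIf(
      &root, [](Node *) { return true; },            // condition: all nodes
      [](Node *n) { return n->name == "c"; },        // stop at node "c"
      [](Node *n) { std::cout << n->name << "\n"; }, // action: print name
      visited);
  // Prints: root, b, a ("c" hits the stop predicate before its action runs).
  return 0;
}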
+void traverseBackwardUpdateOperandChainIf(
+    Operation *rootOp, std::function<bool(Operation *)> conditionFn,
+    std::function<bool(Operation *)> stopFn,
+    std::function<void(OpBuilder &, Operation *)> actionFn) {
+
+  OpBuilder builder(rootOp->getContext());
+  DenseSet<Operation *> handledOperation;
+
+  traverseBackwardUpdateOperandChainIf(rootOp, conditionFn, stopFn, actionFn, builder,
+                                       handledOperation);
+}
+
+void traverseForwardUpdateUserChainIf(
+    Operation *op, std::function<bool(Operation *)> conditionFn,
+    std::function<bool(Operation *)> stopFn,
+    std::function<void(OpBuilder &, Operation *)> actionFn, OpBuilder &builder,
+    llvm::SmallPtrSet &stopOps) {
+
+  if (!op) {
+    return;
+  }
+
+  if (stopFn(op)) {
+    stopOps.insert(op);
+    return;
+  }
+
+  if (conditionFn(op)) {
+    actionFn(builder, op);
+  }
+
+  for (auto res : op->getResults()) {
+    for (auto userOp : res.getUsers()) {
+      traverseForwardUpdateUserChainIf(userOp, conditionFn, stopFn, actionFn,
+                                       builder, stopOps);
+    }
+  }
+}
+
+// Note: rootOp will also be processed.
+void traverseForwardUpdateUserChainIf(
+    Operation *rootOp, std::function<bool(Operation *)> conditionFn,
+    std::function<bool(Operation *)> stopFn,
+    std::function<void(OpBuilder &, Operation *)> actionFn,
+    llvm::SmallPtrSet &stopOps) {
+
+  OpBuilder builder(rootOp->getContext());
+
+  traverseForwardUpdateUserChainIf(rootOp, conditionFn, stopFn, actionFn,
+                                   builder, stopOps);
+}
+
+bool isMetaUse(Operation *op) { return op->hasAttr("MetaUse"); }
+
+bool isMixUse(Operation *op) { return op->hasAttr("MixUse"); }
+
+IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op) {
+  auto ty = IndirectLoadInterfaceOpType::Undefined;
+  if (isMetaUse(op)) {
+    if (isa(op)) {
+      ty = IndirectLoadInterfaceOpType::Load;
+    } else if (isa(op)) {
+      ty = IndirectLoadInterfaceOpType::Calc;
+    }
+  }
+  return ty;
+}
+
+bool opIsIndirectLoad(Operation *op) {
+  auto opType = getIndirectLoadInterfaceOpType(op);
+  return opType == IndirectLoadInterfaceOpType::Load;
+}
+
+bool opIsIndirectCalc(Operation *op) {
+  auto opType = getIndirectLoadInterfaceOpType(op);
+  return opType == IndirectLoadInterfaceOpType::Calc;
+}
+
+scf::ForOp createNestedLoops(
+    OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims,
+    ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector<Value> &ivs,
+    ValueRange initArgs,
+    function_ref<void(OpBuilder &, Location, SmallVector<Value> &, ValueRange)>
+        bodyBuilder) {
+
+  if (currentDim >= totalDims) {
+    bodyBuilder(builder, loc, ivs, initArgs);
+    return nullptr;
+  }
+
+  auto loop = builder.create<scf::ForOp>(
+      loc, LBs[currentDim], UBs[currentDim], steps[currentDim], initArgs,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
+          ValueRange iterArgs) {
+        ivs.push_back(iv);
+        auto innerLoop = createNestedLoops(nestedBuilder, nestedLoc,
+                                           currentDim + 1, totalDims, LBs, UBs,
+                                           steps, ivs, iterArgs, bodyBuilder);
+        if (innerLoop) {
+          nestedBuilder.create<scf::YieldOp>(loc, innerLoop.getResults());
+        }
+      });
+
+  return loop;
+}
+
+ModuleOp getModuleOpFromOperation(Operation *op) {
+  Operation *parent = op;
+  while (parent != nullptr && !isa<ModuleOp>(parent)) {
+    parent = parent->getParentOp(); // walk up the parent chain
+  }
+  return cast<ModuleOp>(parent); // fails if no ModuleOp ancestor is found
+}
+
+} // namespace triton
+
+
+// TODO: implement these functions below
+OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs,
+                             const Location &loc, OpBuilder &b) {
+  auto lhsInt = getConstantOfAttr(lhs);
+  auto rhsInt = getConstantOfAttr(rhs);
+
+  if (lhsInt && rhsInt)
+    return b.getIndexAttr(lhsInt.value() + rhsInt.value());
+
+  if (!lhsInt && rhsInt && rhsInt.value() == 0)
+    return lhs;
+  if (!rhsInt && lhsInt && lhsInt.value() == 0)
+    return rhs;
+
+  auto lhsValue = dyn_cast<Value>(lhs);
+  if (lhsInt)
+    lhsValue = createConstIndexValueOp(loc, b, lhsInt.value());
+  else
+
assert(isa(lhsValue.getType())); + + auto rhsValue = dyn_cast(rhs); + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() - rhsInt.value()); + + if (!lhsInt && rhsInt && rhsInt.value() == 0) + return lhs; + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() * rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + if (lhsInt.value() == 1) + return rhs; + } + if (rhsInt) { + if (rhsInt.value() == 0) + return rhs; + if (rhsInt.value() == 1) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (rhsInt && rhsInt.value() == 0) { + emitError(loc) << "cannot div 0!"; + return OpFoldResult(); + } + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() / rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + } + + if (rhsInt) { + if (rhsInt.value() == 1) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (rhsInt && rhsInt.value() == 0) { + emitError(loc) << "cannot remainder by 0!"; + return OpFoldResult(); + } + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() % rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult minOpFoldResult(const 
OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + if (lhsInt && rhsInt) + return b.getIndexAttr(std::min(lhsInt.value(), rhsInt.value())); + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, + const Location &loc, OpBuilder &b) { + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + if (lhsInt && rhsInt) + return b.getIndexAttr(std::max(lhsInt.value(), rhsInt.value())); + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); +} + +LogicalResult +addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, + linalg::ReduceOp reduceOp) { + // To verify whether the operation of the reduceOp is ReduceWithIndex + // TODO: maybe a better way of judging? + Block &body = reduceOp.getCombiner().front(); + auto yieldOp = dyn_cast(body.getTerminator()); + + auto yieldValue = yieldOp.getValues(); + if (yieldValue.size() == 0) { + return failure(); + } + + auto opIter = reduceOp.getBody()->without_terminator().begin(); + auto cmpMaskOp = dyn_cast(*opIter); + const StringRef reduceRef = "reduce_mode"; + if (cmpMaskOp) { + if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OGT) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); + } else if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OLT) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); + } + } + + auto cmpMaskIOp = dyn_cast(*opIter); + if (cmpMaskIOp) { + if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::sgt || + cmpMaskIOp.getPredicate() == arith::CmpIPredicate::ugt) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); + } else if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::slt || + cmpMaskIOp.getPredicate() == arith::CmpIPredicate::ult) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); + } + } + + return success(); +} + +// Fold layout constant info to attr, otherwise convert to index type value +OpFoldResult getOpFoldResultOfLayoutInfo(Value value, OpBuilder &builder) { + OpFoldResult constantFold = getAsOpFoldResult(value); + if (llvm::isa(constantFold)) { + assert(isa(constantFold.get())); + return constantFold; + } + + if (!isa(value.getType())) + llvm_unreachable("Illegal data type when parse block data layout info"); + + if (!isa(value.getType())) { + if (value.getType().isInteger(/*width*/ 1)) + value = builder.create( + value.getLoc(), builder.getIndexType(), value); + else + value = builder.create(value.getLoc(), + builder.getIndexType(), value); + } + + return value; +} + +} // namespace mlir diff --git a/third_party/ascend/triton-adapter/safe_compile.cmake b/third_party/ascend/triton-adapter/safe_compile.cmake new file mode 100644 index 000000000..2b29614c7 --- /dev/null +++ 
b/third_party/ascend/triton-adapter/safe_compile.cmake @@ -0,0 +1,10 @@ +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIE") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIE") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-strong") +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,now -pie -s") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,now -s") +set(CMAKE_SKIP_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +unset(CMAKE_INSTALL_RPATH) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/tools/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/CMakeLists.txt new file mode 100644 index 000000000..fff219ac0 --- /dev/null +++ b/third_party/ascend/triton-adapter/tools/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton-adapter-opt) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt new file mode 100644 index 000000000..04dcc2de6 --- /dev/null +++ b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/CMakeLists.txt @@ -0,0 +1,23 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + +add_llvm_executable(triton-adapter-opt triton-adapter-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? +llvm_update_compile_flags(triton-adapter-opt) +target_link_libraries(triton-adapter-opt PRIVATE + TritonToAnnotation + TritonToHIVM + DiscreteMaskAccessConversion + TritonToLinalg + TritonToUnstructure + TritonTransforms + ${dialect_libs} + ${conversion_libs} + TritonGPUTransforms + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-adapter-opt) \ No newline at end of file diff --git a/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp new file mode 100644 index 000000000..49c9436e9 --- /dev/null +++ b/third_party/ascend/triton-adapter/tools/triton-adapter-opt/triton-adapter-opt.cpp @@ -0,0 +1,40 @@ +#include "TritonToAnnotation/Passes.h" +#include "TritonToHIVM/Passes.h" +#include "DiscreteMaskAccessConversion/Passes.h" +#include "TritonToLinalg/Passes.h" +#include "TritonToUnstructure/Passes.h" +#include "bishengir/InitAllDialects.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +int main(int argc, char **argv) { + // Register all dialects. + mlir::DialectRegistry registry; + registry.insert(); + bishengir::registerAllDialects(registry); + + // Register all passes. 
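+  // Each register*Pass() call below adds the pass to MLIR's global pass
+  // registry, which is what exposes it as a command-line option of this
+  // mlir-opt-style driver.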
+  mlir::triton::registerTritonToLinalgPass();
+  mlir::triton::registerTritonToAnnotationPass();
+  mlir::triton::registerTritonToHIVMPass();
+  mlir::triton::registerDiscreteMaskAccessConversionPass();
+  mlir::triton::registerBubbleUpOperationPass();
+  mlir::triton::registerTritonToUnstructurePass();
+
+  return mlir::asMainReturnCode(
+      mlir::MlirOptMain(argc, argv, "Triton-Adapter test driver\n", registry));
+}
diff --git a/third_party/ascend/triton-adapter/triton_adapter.cc b/third_party/ascend/triton-adapter/triton_adapter.cc
new file mode 100644
index 000000000..6f5ebd97d
--- /dev/null
+++ b/third_party/ascend/triton-adapter/triton_adapter.cc
@@ -0,0 +1,6 @@
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+// compilation is handled by triton-adapter-opt; nothing to register here
+void init_triton_triton_adapter(py::module &&m) {}
\ No newline at end of file
diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp
new file mode 100644
index 000000000..d323f95af
--- /dev/null
+++ b/third_party/ascend/triton_ascend.cpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+ */
+#include "mlir/Pass/PassManager.h"
+#include "passes.h"
+#include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimental.h"
+
+#define PY_SSIZE_T_CLEAN
+#include <pybind11/pybind11.h>
+namespace py = pybind11;
+
+void init_triton_ascend_passes_convert(py::module &&m) {
+  ADD_PASS_WRAPPER_0("add_triton_to_linalg_pipeline",
+                     mlir::triton::createTritonToLinalgExperimentalPass);
+}
+
+// register Ascend passes with Triton
+void init_triton_ascend(py::module &&m) {
+  auto passes = m.def_submodule("passes");
+  init_triton_ascend_passes_convert(passes.def_submodule("convert"));
+}
\ No newline at end of file
diff --git a/third_party/ascend/triton_patch/include/CMakeLists.txt b/third_party/ascend/triton_patch/include/CMakeLists.txt
new file mode 100644
index 000000000..109c292fe
--- /dev/null
+++ b/third_party/ascend/triton_patch/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(triton)
diff --git a/third_party/ascend/triton_patch/include/runtime/libentry/libentry.h b/third_party/ascend/triton_patch/include/runtime/libentry/libentry.h
new file mode 100644
index 000000000..5b9ff3388
--- /dev/null
+++ b/third_party/ascend/triton_patch/include/runtime/libentry/libentry.h
@@ -0,0 +1,59 @@
+#ifndef LIBENTRY_H
+#define LIBENTRY_H
+
+#include 
+#include 
+#include 
+#include 
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+namespace py = pybind11;
+
+using KeyType = py::tuple;
+
+namespace libentry {
+
+class ArgProcessor {
+public:
+  ArgProcessor(int div) : divisibility_(div) {}
+
+  void classifyArguments(
+      const py::list& args,
+      const py::dict& kwargs,
+      const py::list& jit_params,
+      const std::unordered_set<int>& specialize_indices,
+      const std::unordered_set<int>& do_not_specialize_indices);
+
+  KeyType generateKey();
+
+  py::list getKArgs();
+
+private:
+  py::list spec_args_;   // specialize args
+  py::list dns_args_;    // do not specialize args
+  py::list const_args_;  // constexpr args
+  py::list k_args_;      // kernel args
+  int divisibility_;     // alignment
+};
+
+} // namespace libentry
+
+PYBIND11_MODULE(libentry_ascend, m) {
+  py::class_<libentry::ArgProcessor>(m, "ArgProcessor")
+      .def(py::init<int>())
+      .def("classify_arguments", &libentry::ArgProcessor::classifyArguments,
+           py::arg("args"),
+           py::arg("kwargs"),
+           py::arg("jit_params"),
+           py::arg("specialize_indices"),
+           py::arg("do_not_specialize_indices"),
+           "classify arguments")
+      .def("get_k_args", &libentry::ArgProcessor::getKArgs,
+           "get kernel args")
+      .def("generate_key", 
&libentry::ArgProcessor::generateKey, + "generate kernel cache key"); +} + +#endif \ No newline at end of file diff --git a/third_party/ascend/triton_patch/include/triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/CMakeLists.txt new file mode 100644 index 000000000..0ca0f41c5 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 000000000..5e601271e --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..990e3b68f --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,39 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +file(RELATIVE_PATH patch_rel_dir "${CMAKE_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") +string(REPLACE "triton_patch" "third_party/triton" triton_rel_dir "${patch_rel_dir}") + set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") # message(STATUS "triton_abs_dir: ${triton_abs_dir}") + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. +# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypes.td) +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypeInterfaces.td) +mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) + +# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. 
+# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonOpInterfaces.td) +set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) + +add_public_tablegen_target(Patched_TritonTableGen) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt.bak b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt.bak new file mode 100644 index 000000000..9b004a8bd --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt.bak @@ -0,0 +1,40 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +file(RELATIVE_PATH patch_rel_dir "${CMAKE_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") +string(REPLACE "triton_patch" "third_party/triton" triton_rel_dir "${patch_rel_dir}") +set(triton_abs_dir "${CMAKE_SOURCE_DIR}/${triton_rel_dir}") +# message(STATUS "triton_abs_dir: ${triton_abs_dir}") + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +# add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. +# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypes.td) +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypeInterfaces.td) +mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) + +# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. +# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonOpInterfaces.td) +set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) + +add_public_tablegen_target(Patched_TritonTableGen) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/Dialect.h new file mode 100644 index 000000000..c0b0885ed --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/Dialect.h @@ -0,0 +1,75 @@ +// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. 
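+// This header is a local copy kept under triton_patch: it pulls in the
+// tablegen-generated dialect and op declarations and declares the
+// DialectInferLayoutInterface, mirroring the upstream
+// triton/Dialect/Triton/IR/Dialect.h until the upgrade noted above.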
+#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" + +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "mlir/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { + +struct GlobalMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +class DialectInferLayoutInterface : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult inferTransOpEncoding(Attribute operandEncoding, ArrayRef order, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, Attribute &resultEncoding, + std::optional location) const = 0; + + // Note: This function only verifies the operand encoding. It doesn't infer + // the result encoding. + virtual LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, Attribute retEncoding, + std::optional location) const = 0; + + // Tries to compute the encoding for the result of a reshape operation that + // makes the reshape a "nop", i.e. the same GPU threads contain the same + // elements as before the reshape. Note that this is not always possible (in + // which case you'd need to choose a different layout for the input to the + // reshape). + virtual LogicalResult inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + // Verify that the encoding are compatible to be used together in a dot + // operation + virtual LogicalResult verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const = 0; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/OpInterfaces.h new file mode 100644 index 000000000..0745bb045 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -0,0 +1,25 @@ +// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. 
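+// Local copy kept under triton_patch alongside Dialect.h above: it declares
+// the interface verifier hooks and includes the tablegen-generated
+// OpInterfaces.h.inc at the end of the file.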
+#ifndef TRITON_IR_OP_INTERFACES_H_ +#define TRITON_IR_OP_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { + +namespace triton { + +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op); + +LogicalResult verifyDotOpInterface(Operation *op); + +} // namespace impl + +} // namespace triton +} // namespace mlir + +#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc" + +#endif // TRITON_IR_OP_INTERFACES_H_ diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 000000000..342400f76 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,139 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + I32EnumAttrCase<"CV", 7, "cv">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. 
+def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee">, + I32EnumAttrCase<"HF32", 3, "hf32">, + ]>{ + let cppNamespace = "::mlir::triton"; +} + +// Type for F8F6F4 kind of floats. +def TT_F8F6F4TypeAttr : I32EnumAttr< + "F8F6F4Type", "", + [ + I32EnumAttrCase<"E4M3", 0, "e4m3">, + I32EnumAttrCase<"E5M2", 1, "e5m2">, + I32EnumAttrCase<"E2M3", 2, "e2m3">, + I32EnumAttrCase<"E3M2", 3, "e3m2">, + I32EnumAttrCase<"E2M1", 4, "e2m1">, + I32EnumAttrCase<"BF16", 5, "bf16">, + I32EnumAttrCase<"FP16", 6, "fp16"> + + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td new file mode 100644 index 000000000..720bfbd7b --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -0,0 +1,37 @@ +// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. +#ifndef TRITON_OP_INTERFACES +#define TRITON_OP_INTERFACES + +include "mlir/IR/OpBase.td" + + +def TT_DescriptorOpInterface : OpInterface<"DescriptorOpInterface"> { + let description = [{ + Common interface to get the descriptor argument from an operation on tensor descriptors. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get the descriptor", + /*retType=*/"::mlir::TypedValue", + /*methodName=*/"getDesc", + /*args=*/(ins)>, + ]; +} + +def TT_DescriptorStoreLikeOpInterface : OpInterface<"DescriptorStoreLikeOpInterface", [TT_DescriptorOpInterface]> { + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/"Get Source tensor", + /*retType=*/"::mlir::TypedValue", + /*methodName=*/"getSrc", + /*args=*/(ins)>, + ]; +} + + +#endif // TRITON_OP_INTERFACES diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 000000000..5c4dd8387 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1435 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface 
+include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + // TODO: Add verifier +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. + + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatTensor:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatTensor:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. 
+ }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. + }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let hasFolder = 1; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + SameLoadStoreOperandsAndResultShape, + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins 
"Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // assemblyFormat instead of as part of attr-dict so that they get printed + // as strings rather than opaque integers. + // + // Note there's no comma between `other` and `cacheModifier` and between + // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // limitation in the MLIR custom-format parser. In oilist, the initial + // keywords of each clause have to be unique, so they can't be `,`. + // + // Even if we gave up on order-independence and used vanilla optional + // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // not match the string ", bar = 0" because after the initial comma (first + // token of the first optional clause) we expect to see "foo". + let assemblyFormat = [{ + $ptr (`,` $mask^)? (`,` $other^)? + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? 
+ oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. 
+ }]; + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +// cat is not `pure` because it may reorder elements +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + NoMemoryEffect, SameTypeOperands, + DeclareOpInterfaceMethods, +]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_SplitOp : TT_Op<"split", [ + NoMemoryEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_SortOp : TT_Op<"sort", [ + NoMemoryEffect, + DeclareOpInterfaceMethods +]> { + let summary = "Sorts a tensor along a given dimension and returns sorted values."; + let description = [{ + Sorts the elements of the input tensor along the specified dimension. + Returns one tensor: + The sorted tensor (same shape and element type as input). 
+ }]; + + let arguments = (ins + TT_Tensor:$src, // Input tensor + I64Attr:$dim, // Dimension to sort along + BoolAttr:$descending // Sort order + ); + + let results = (outs + TT_Tensor:$sorted // Sorted values + ); + + let assemblyFormat = "$src `,` $dim `,` $descending attr-dict `:` type($src) `->` type($sorted)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. + }]; + + let arguments = ( + ins TT_TensorOrMemDesc:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorOrMemDesc:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. 
For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + + +// +// DotScaled Op +// +def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot_scaled"; + + let description = [{ + $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c. + Where scale(x, s) is a function that applies the scale per block following microscaling spec. + }]; + + let arguments = ( + ins + // inputs are integer types as they are packed types and we currently + // don't have a representation for those. + TT_FpIntTensor:$lhs, + TT_FpIntTensor:$rhs, + TT_FloatTensor:$c, + TT_IntTensor:$lhs_scale, + Optional:$rhs_scale, + TT_F8F6F4TypeAttr:$lhs_type, + TT_F8F6F4TypeAttr:$rhs_type + ); + + let results = (outs TT_FloatTensor:$d); + + // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file + let assemblyFormat = [{ + $lhs `,` $rhs `,` $c `,` $lhs_scale (`,` $rhs_scale^)? + `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict + `:` type($lhs) `,` type($rhs) `,` type($c) `,` type($lhs_scale) (`,` type($rhs_scale)^)? 
`->` type($d) + }]; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsShape, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. 
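+  //
+  // Illustrative printed form, assumed from the assembly format declared
+  // below (the SSA name is hypothetical):
+  //   %range = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>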
+ let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure]> { + let summary = "return a histgram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src); + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Gather Op +// +def TT_GatherOp : TT_Op<"gather", [Pure, + DeclareOpInterfaceMethods]> { + let summary = "local gather operation"; + let description = [{ + Gather elements from the input tensor using the indices tensor along a + single specified axis. The output tensor has the same shape as the indices + tensor. The input and indices tensors must have the same number of + dimension, and each dimension of the indices tensor that is not the gather + dimension cannot be greater than the corresponding dimension in the input + tensor. + }]; + + let arguments = (ins TT_Tensor:$src, TT_IntTensor:$indices, I32Attr:$axis); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `[` $indices `]` attr-dict `:` + functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> { + let arguments = ( + ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$args, + DenseI32ArrayAttr:$isSigned + ); + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor and a message string. + If the condition is false, the message is printed, and the program is aborted. 
+ }]; + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// +// Make Tensor Descriptor Op +// +def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ + Pure, + SameVariadicOperandSize, +]> { + let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size"; + + let description = [{ + `tt.make_tensor_descriptor` takes both meta information of the parent tensor and the block size, + and returns a descriptor object which can be used to load/store from the tensor in global memory. + }]; + + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides + ); + + let results = (outs TT_TensorDescType:$result); + + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)"; + + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef":$blockShape, "bool":$isSignedInteger)> + ]; + + let extraClassDeclaration = [{ + ArrayRef getTensorShape() { + return getType().getBlockType().getShape(); + } + }]; +} + +def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc` is a tensor descriptor object. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. 
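+
+    Example (an illustrative sketch, assuming a descriptor whose block type is
+    `tensor<128x64xf16>`; the exact printed form may differ):
+
+    ```mlir
+    %tile = tt.descriptor_load %desc[%off_m, %off_n]
+        : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16>
+    ```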
+ }]; + let arguments = (ins + Arg]>:$desc, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc)) `->` type($result) + }]; + + let hasVerifier = 1; +} + +def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpInterface]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc` is a tensor descriptor object. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + }]; + let arguments = (ins + Arg, MemWrite]>:$desc, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc)) `,` type($src) + }]; + let hasVerifier = 1; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, // triton_v3.3.x + OptionalAttr:$res_attrs); // triton_v3.3.x + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. 
+ void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. 
+ ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ + MemoryEffects<[MemRead]>]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + }]; +} + +def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ + MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. 
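+
+    Example (an illustrative sketch; `%desc_ptr` is assumed to point at a TMA
+    descriptor in global memory, and the exact printed form may differ):
+
+    ```mlir
+    tt.experimental_descriptor_store %desc_ptr[%off_m, %off_n], %tile
+        : !tt.ptr<i8>, tensor<128x64xf16>
+    ```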
+ }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + }]; +} + +def TT_ExperimentalTensormapCreateOp: TT_Op< + "experimental_tensormap_create", + [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments, + ] +> { + let summary = "Create a new TMA descriptor on device"; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_PtrType:$global_address, + Variadic:$box_dim, + Variadic:$global_dim, + Variadic:$global_stride, + Variadic:$element_stride, + ConfinedAttr]>:$elem_type, + ConfinedAttr]>:$interleave_layout, + ConfinedAttr]>:$swizzle_mode, + ConfinedAttr]>:$fill_mode + ); + let extraClassDeclaration = [{ + int32_t getRank() { + return getBoxDim().size(); + } + }]; + let assemblyFormat = [{ + $desc_ptr `,` $global_address `,` + `[` $box_dim `]` `,` + `[` $global_dim `]` `,` + `[` $global_stride `]` `,` + `[` $element_stride `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< + "experimental_tensormap_fenceproxy_acquire", + [MemoryEffects<[MemWrite]>] +> { + let summary = "Acquire fence on a tensormap object"; + let arguments = (ins TT_PtrType:$desc_ptr); + let assemblyFormat = [{ + $desc_ptr attr-dict `:` qualified(type($desc_ptr)) + }]; +} + + +// +// Annotation Op +// +def TT_AnnotationOp : TT_Op<"annotation", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "Annotate a tensor with key-value attribute pairs"; + let description = [{ + `tt.annotation` operation can be used to annotate a tensor with + key-value attribute pairs. + + Example: + ```mlir + tt.annotation %target {key : val} + ``` + }]; + let arguments = (ins TT_Tensor:$src); + let assemblyFormat = [{ + $src attr-dict `:` type($src) + }]; +} + +// +// Custom Op +// +def TT_CustomOp : TT_Op<"custom", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "self-defined custom operation"; + let description = [{ + `tt.custom` triton custom op is designed to pass self-defined custom operation. + + Example: + ```tt.custom {str_args = ["sync_block_wait", "cube"]} loc(#loc12) + ``` + }]; + let arguments = (ins StrAttr:$op_name, ArrayAttr:$str_args, Variadic:$args); + + let assemblyFormat = "$op_name attr-dict ($args^ `:` type($args))?"; +} + +#endif // Triton_OPS diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonTypes.td new file mode 100644 index 000000000..5f9f19a07 --- /dev/null +++ b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,178 @@ +// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. 
+#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" + +// +// Types +// +class TritonTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : RankedTensorOf<[TT_Float]>; +def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def TT_BoolTensor : RankedTensorOf<[I1]>; +def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; + +// Integer Type +def I4 : I<4>; +def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; +def TT_IntTensor : RankedTensorOf<[TT_Int]>; +def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; + +// I32 Type +// TT_I32 -> I32 +// TT_I32Tensor -> I32Tensor +def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// TT_I64 -> I64 +// TT_I64Tensor -> I64Tensor +def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class TT_PtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `TT_PtrOf`) +def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def TT_Ptr : TT_PtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>; + +// Tensor Type +def TT_FpIntTensor : RankedTensorOf<[TT_Float, TT_Int]>; +def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>; + +// Pointer Type to Tensor Type: `ptr>` +def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; + +// Any Type in Triton IR +def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>; + +// Memory descriptor type. +def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. 
+ }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutable_memory + ); + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory()); + } + + bool hasRank() const { return true; } + }]; + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "Attribute":$memorySpace, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory); + }]> + ]; + let hasCustomAssemblyFormat = 1; +} + +// Result type of MakeTensorDescriptor +def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> { + let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system"; + + let description = [{ + A portable abstraction for nvidia-TMA descriptors. + }]; + + let parameters = (ins "RankedTensorType":$blockType); + let assemblyFormat = "`<` $blockType `>`"; + + let builders = [ + TypeBuilder<(ins "RankedTensorType":$blockType, "bool":$isSigned), [{ + if (auto intTy = llvm::dyn_cast(blockType.getElementType())) { + auto sem = isSigned ? IntegerType::Signed : IntegerType::Unsigned; + auto elemTy = IntegerType::get($_ctxt, intTy.getWidth(), sem); + blockType = RankedTensorType::get(blockType.getShape(), elemTy); + } + return Base::get($_ctxt, blockType); + }]>, + ]; + let extraClassDeclaration = [{ + RankedTensorType getSignlessBlockType() const { + auto resTy = getBlockType(); + if (auto intTy = llvm::dyn_cast(resTy.getElementType())) { + auto width = resTy.getElementTypeBitWidth(); + auto signlessTy = IntegerType::get(getContext(), width); + resTy = RankedTensorType::get(resTy.getShape(), signlessTy); + } + return resTy; + } + }]; +} + +#endif diff --git a/third_party/ascend/triton_patch/lib/CMakeLists.txt b/third_party/ascend/triton_patch/lib/CMakeLists.txt new file mode 100644 index 000000000..0ca0f41c5 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..5e601271e --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..3b7c3746a --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(Patched_TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp 
+ + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..1d9c86f4d --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,140 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. 
+ assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; + +struct TensorModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; + +} // namespace + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..2ab26afa9 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,1183 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, 
CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) 
-> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult TransOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + if (auto memDescTy = dyn_cast(argTy)) { + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), + memDescTy.getMutableMemory())); + } else { + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +//-- SortOp -- +LogicalResult SortOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) + { + if (operands.size() != 1) { + return emitOptionalError(location, "expected exactly one operand for SortOp"); + } + + if (!isa(operands[0].getType())) { + return emitOptionalError(location, "operand must be a ranked tensor type for SortOp"); + } + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto srcEnc = srcTy.getEncoding(); + + if (srcShape.empty()) { + return emitOptionalError(location, "input tensor must have rank >= 1"); + } + + Type sortedTy = RankedTensorType::get(srcShape, srcTy.getElementType(), srcEnc); + + inferredReturnTypes.push_back(sortedTy); + + return success(); +} + +LogicalResult TransOp::verify() { + // Check that the op's `order` attribute is a permutation of the right length. 
+ auto srcTy = getSrc().getType(); + + ArrayRef order = getOrder(); + if (order.size() != srcTy.getRank()) { + return emitError("order must have the same size as the rank of the " + "operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc && retEnc); + Dialect &dialect = retEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. 
+ if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + auto accTy = getC().getType(); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis, + SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +void ReduceOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = cast(operands[i].getType()); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 
operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. 
Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + state.addAttribute("reverse", builder.getBoolAttr(reverse)); + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + if (!isa(value)) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, 
retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, + op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- +template +LogicalResult canonicalizeViewOrBroadcast(OpType op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // view(view) -> view + if (auto parentView = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), + parentView->getOperands(), + parentView->getAttrs()); + return success(); + } + + // view(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (!op.getAllowReorder() || op.getEfficientLayout()) + return failure(); + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither 
does."); + } + + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } + } + + return success(); +} + +//-- FpToFpOp -- +LogicalResult FpToFpOp::verify() { + auto dstType = getType().getElementType(); + auto srcType = getSrc().getType().getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BroadcastOp -- +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +LogicalResult BroadcastOp::verify() { + auto src = getSrc(); + auto srcTensorType = cast(src.getType()); + auto srcShape = srcTensorType.getShape(); + auto result = getResult(); + auto resultTensorType = cast(result.getType()); + auto resultShape = resultTensorType.getShape(); + if (srcShape.size() != resultShape.size()) { + return emitError("rank of source must be same as rank of result"); + } + for (int i = 0; i < srcShape.size(); i++) { + if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { + return emitError("Different dimensions at index ") + << i << " between source and result. 
" + << "Broadcast requires the source dimension to be 1."; + } + } + return success(); +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) +{ + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, 1); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +//-- AdvanceOp -- +OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { + // advance(ptr, 0, 0) -> ptr + SmallVector rawOffsets = getOffsets(); + auto offsets = getConstantIntValues(rawOffsets); + if (!offsets.has_value()) + return {}; + for (int64_t offset : offsets.value()) + if (offset != 0) + return {}; + return getPtr(); +} + +//-- MakeTensorDescOp -- +void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ArrayRef blockShape, + bool isSignedInteger) +{ + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = + TensorDescType::get(builder.getContext(), blockTy, isSignedInteger); + return build(builder, state, descTy, base, shape, strides); +} + +// -- DescriptorLoadOp -- +static LogicalResult verifyDescriptorLoadStoreType(Operation *op, + TensorDescType desc, + RankedTensorType tensor) +{ + RankedTensorType block = desc.getSignlessBlockType(); + ArrayRef blockShape = block.getShape(); + ArrayRef tensorShape = tensor.getShape(); + if (blockShape.size() > tensorShape.size()) { + // Allow ranked reduced load if the leading dimensions are all 1s. + for (int i = 0; i < blockShape.size() - tensorShape.size(); ++i) { + if (blockShape[i] != 1) + return op->emitOpError( + "ranked reduce load only allowed for unit dimension leading dim."); + } + blockShape = blockShape.take_back(tensorShape.size()); + } + + if (blockShape == tensorShape && + block.getElementType() == tensor.getElementType()) { + return success(); + } + return op->emitOpError("tensor descriptor block and tensor types must match"); +} + +LogicalResult DescriptorLoadOp::verify() +{ + return verifyDescriptorLoadStoreType(*this, getDesc().getType(), getType()); +} + +// -- DescriptorStoreOp -- +LogicalResult DescriptorStoreOp::verify() +{ + return verifyDescriptorLoadStoreType(*this, getDesc().getType(), + getSrc().getType()); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. 
+//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); +#if LLVM_VERSION_MAJOR < 21 + function_interface_impl::addArgAndResultAttrs( +#else // triton_v3.3.x + call_interface_impl::addArgAndResultAttrs( +#endif + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. 
+ const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 2); + assert(operands[0].getType() == operands[1].getType()); + assert(isa(operands[0].getType())); + assert(isa(operands[1].getType())); + + Value lhs = operands[0]; + auto srcTy = cast(lhs.getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 1); + assert(isa(operands[0].getType())); + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? 
tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- ExperimentalTensormapCreateOp -- +LogicalResult ExperimentalTensormapCreateOp::verify() { + auto rank = getBoxDim().size(); + if (getGlobalDim().size() != rank) { + return emitError("Rank mismatch for global dim. Got") + << getGlobalDim().size() << " but expected " << rank; + } + if (getGlobalStride().size() + 1 != rank) { + return emitError("Rank mismatch for global stride. Got") + << getGlobalStride().size() << " but expected " << rank - 1; + } + if (getElementStride().size() != rank) { + return emitError("Rank mismatch for element stride. Got") + << getElementStride().size() << " but expected " << rank; + } + return success(); +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. 
+ inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..b43a9b56c --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,239 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", 
but " << *op + << " has more than that"; + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto module = op->getParentOfType(); + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = + ttg::TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the module specifies " + << moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! 
+ operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..6e41e70a8 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,197 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +static constexpr llvm::StringRef kMutableMemory = 
"mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) + return Type(); + + // Parse the element type. + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + Attribute encoding; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(encoding)) + return Type(); + } + bool mutableMemory = false; + Attribute memorySpace; + if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseOptionalKeyword(kMutableMemory))) { + if (parser.parseAttribute(memorySpace)) + return Type(); + } else { + mutableMemory = true; + } + } + if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { + if (parser.parseOptionalKeyword(kMutableMemory)) + return Type(); + mutableMemory = true; + } + if (parser.parseGreater()) + return Type(); + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, memorySpace, mutableMemory); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + printer << ">"; +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + +// upstream Triton only uses address space 1 for Pointer Type +Type getPointerType(Type type, int addressSpace) { + return PointerType::get(type, addressSpace); +} + +int getAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getAddressSpace(); + return 1; +} + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return 
isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/ascend/triton_patch/lib/runtime/libentry/libentry.cpp b/third_party/ascend/triton_patch/lib/runtime/libentry/libentry.cpp new file mode 100644 index 000000000..7374fec74 --- /dev/null +++ b/third_party/ascend/triton_patch/lib/runtime/libentry/libentry.cpp @@ -0,0 +1,102 @@ +#include "runtime/libentry/libentry.h" + +using namespace libentry; + +void libentry::ArgProcessor::classifyArguments( + const py::list& args, + const py::dict& kwargs, + const py::list& jit_params, + const std::unordered_set& specialize_indices, + const std::unordered_set& do_not_specialize_indices) +{ + for (size_t i = 0; i < args.size(); ++i) { + if (specialize_indices.count(i)) { + k_args_.append(args[i]); + spec_args_.append(args[i]); + } else if (do_not_specialize_indices.count(i)) { + k_args_.append(args[i]); + dns_args_.append(args[i]); + } else { + const_args_.append(args[i]); + } + } + + for (size_t i = args.size(); i < jit_params.size(); ++i) { + const py::object& param = jit_params[i]; + py::object val; + + if (kwargs.contains(param.attr("name"))) { + val = kwargs[param.attr("name")]; + } else if (py::hasattr(param, "default") && !param.attr("default").is_none()) { + val = param.attr("default"); + } else { + continue; + } + + if (param.attr("is_constexpr").cast()) { + const_args_.append(val); + } else if (param.attr("do_not_specialize").cast()) { + dns_args_.append(val); + k_args_.append(val); + } else { + spec_args_.append(val); + k_args_.append(val); + } + } +} + +KeyType libentry::ArgProcessor::generateKey() +{ + auto is_tensor = [](py::handle x) { + return py::hasattr(x, "data_ptr"); + }; + auto is_int = [](py::handle x) { + return py::isinstance(x); + }; + + py::list spec_key; + for (auto arg : spec_args_) { + if (is_tensor(arg)) { + auto dtype = arg.attr("dtype"); + uintptr_t data_ptr = arg.attr("data_ptr")().cast(); + bool aligned = (data_ptr & (divisibility_ - 1)) == 0; + spec_key.append(py::make_tuple(dtype, aligned)); + } else { + spec_key.append(py::make_tuple(py::type::of(arg), arg)); + } + } + + py::list dns_key; + for (auto arg : dns_args_) { + if (is_tensor(arg)) { + dns_key.append(arg.attr("dtype")); + } else if (!is_int(arg)) { + dns_key.append(py::type::of(arg)); + } else { + int64_t val = arg.cast(); + if (val >= -0x80000000LL && val <= 0x7FFFFFFFLL) { + dns_key.append(py::str("i32")); + } else if (val >= 0 && val <= 0xFFFFFFFFFFFFFFFFLL) { + dns_key.append(py::str("u64")); + } else { + dns_key.append(py::str("i64")); + } + } + } + + py::list result; + auto list_append = [&](const py::list& src) { + for (auto handle : src) { + result.append(handle); + } + }; + list_append(spec_key); + list_append(dns_key); + list_append(const_args_); + return result; +} + +py::list libentry::ArgProcessor::getKArgs() +{ + return k_args_; +} \ No newline at end of file diff --git a/third_party/ascend/triton_patch/python/src/ir.cc b/third_party/ascend/triton_patch/python/src/ir.cc new file mode 100644 index 000000000..a2fd115a3 --- /dev/null +++ b/third_party/ascend/triton_patch/python/src/ir.cc @@ -0,0 +1,1842 @@ +#include +#include +#include +#include + +#include 
"mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/SourceMgr.h" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. 
+ template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .value("HF32", InputPrecision::HF32) + .export_values(); + + py::enum_(m, "F8F6F4TY", py::module_local()) + .value("E4M3", F8F6F4Type::E4M3) + .value("E5M2", F8F6F4Type::E5M2) + .value("E2M3", F8F6F4Type::E2M3) + .value("E3M2", F8F6F4Type::E3M2) + .value("E2M1", F8F6F4Type::E2M1) + .value("BF16", F8F6F4Type::BF16) + .value("FP16", F8F6F4Type::FP16) + .export_values(); + + py::class_(m, "context", py::module_local()) + .def(py::init<>()) + .def("printOpOnDiagnostic", + [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) + 
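      // [Editorial aside, illustrative only; not part of the ascend patch.]
      // For orientation in this long binding file, the pieces bound so far are
      // typically driven from the Python side roughly as follows. The module
      // path and frontend wiring are assumptions; the method names are exactly
      // the ones bound in this file:
      //
      //   ctx = ir.context()             # the py::class_<MLIRContext> binding above
      //   ir.load_dialects(ctx)          # m.def("load_dialects", ...) just below
      //   builder = ir.builder(ctx)      # the TritonOpBuilder wrapper bound later
      //   mod = builder.create_module()  # ModuleOp, printable via mod.str()
      //
      // The enum bindings above (CACHE_MODIFIER, EVICTION_POLICY, ...) are the
      // values that later flow into builder.create_load / create_store and the
      // other memory-op bindings further down in this file.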
.def("printStackTraceOnDiagnostic", + [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }) + .def("disable_multithreading", + [](MLIRContext &self) { self.disableMultithreading(); }); + + py::class_(m, "source_mgr_diag", + py::module_local()) + .def(py::init()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + mlir::LLVM::registerInlinerInterface(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + mlir::LLVM::registerInlinerInterface(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), 
self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "unit_attr", py::module_local()); + py::class_(m, "str_attr", py::module_local()); + py::class_(m, "array_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self->print(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + 
.def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_bool_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::bool_(ret.getValue()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + if (arg_no >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + 
self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Check if the result of tl.advance is used + self.walk([&](AdvanceOp op) { + if (op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + .def("get_str_attr", + [](TritonOpBuilder &self, std::string value) { + return self.getBuilder().getStringAttr(value); + }) + .def("get_unit_attr", + [](TritonOpBuilder &self) { + return self.getBuilder().getUnitAttr(); + }) + .def("get_i64_array_attr", + [](TritonOpBuilder &self, const std::vector& array) { + return self.getBuilder().getI64ArrayAttr(array); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder 
&self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, 
+ int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else 
+ return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_to_index", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getIndexType(), input); + }) + .def("create_index_to_si", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getI64Type(), input); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { 
+ return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + 
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector 
&boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_tensor_descriptor_type", + [](TritonOpBuilder &self, Type blockTy, bool isSigned) -> Type { + auto ctx = self.getBuilder().getContext(); + return triton::TensorDescType::get(ctx, cast(blockTy), isSigned); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value desc, std::vector &indices, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + auto descTy = cast(desc.getType()); + auto resTy = descTy.getSignlessBlockType(); + return self.create(resTy, desc, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value desc, Value value, std::vector &indices) -> void { + self.create(desc, value, indices); + }) + .def("create_tensormap_create", + [](TritonOpBuilder &self, Value desc_ptr, Value global_address, + std::vector box_dim, std::vector global_dim, + std::vector global_stride, + std::vector element_stride, int32_t elem_type, + int32_t interleave_layout, int32_t swizzle_mode, + int32_t fill_mode) { + self.create( + desc_ptr, global_address, box_dim, global_dim, global_stride, + element_stride, elem_type, interleave_layout, swizzle_mode, + fill_mode); + }) + .def("create_tensormap_fenceproxy_acquire", + [](TritonOpBuilder &self, Value desc_ptr) { + self.create(desc_ptr); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. 
Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + .def("create_extract_scalar", + [](TritonOpBuilder &self, Value &src, std::vector &indices) -> Value { + llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) + .def("create_extract_slice", + [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get(retSizes, + cast(ful.getType()).getElementType()); + + return self.create(retTy, ful, offsets, sizes, strides); + }) + .def("create_insert_slice", + [](TritonOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + sizes, strides); + return ret; + }) + // Implements tl.trans and tl.permute. 
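+ // The result shape is the input shape permuted by `order`; e.g. a 2x8
+ // tensor with order = [1, 0] becomes an 8x2 tensor (a plain 2-D transpose).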
+ .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, + F8F6F4Type lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, F8F6F4Type rhs_format, + mlir::Value &c) -> mlir::Value { + return self.create( + c.getType(), lhs, rhs, c, lhs_scale, + rhs_scale.value_or(Value()), lhs_format, rhs_format); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return 
self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + self.create(condition, messageAttr); + }) + .def("create_assume", + [](TritonOpBuilder &self, Value &condition) { + self.create(condition); + }) + .def("create_poison", + [](TritonOpBuilder &self, Type &type) -> Value { + return 
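+ // A poison value of the requested type: a placeholder result that the
+ // program is assumed never to observe (cf. LLVM/MLIR poison semantics).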
self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }) + // Add custom op + .def("create_custom_op_for_inter_core_sync", + [](TritonOpBuilder &self, std::string &op_name, + std::string &mode_or_sender, int id) -> void { + auto args = self.getBuilder().getArrayAttr( + {self.getBuilder().getStringAttr(mode_or_sender), + self.getBuilder().getI32IntegerAttr(id)} + ); + self.create(op_name, args, ValueRange()); + }) + // Make a tensor descriptor + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Value &base, std::vector &shape, std::vector &strides, + std::vector &tensorShape, bool isSignedInteger) -> Value { + return self.create(base, shape, strides, tensorShape, isSignedInteger); + }) + // Add an annotation + .def("create_annotation", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }) + // Add sort + .def("create_sort", + [](TritonOpBuilder &self, Value src, int64_t dim, bool descending) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto dimAttr = builder.getI64IntegerAttr(dim); + auto descendingAttr = builder.getBoolAttr(descending); + + auto op = builder.create(loc, src, dimAttr, descendingAttr); + + return op->getResult(0); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) { + auto *context = self.getContext(); + bool haveDiagnostics = + ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + std::string funcToDump; + if (!haveDump) { + funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); + if (!funcToDump.empty()) + haveDump = true; + } + if (haveDiagnostics || haveDump) { + context->disableMultithreading(); + } + if (haveDiagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return success(); + }); + } + if (haveDump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { + if (funcToDump.empty()) + return true; + if (auto mod = dyn_cast(op)) { + return mod.lookupSymbol(funcToDump); + } + if (auto func = dyn_cast(op)) { + return SymbolTable::getSymbolName(func).getValue() == + funcToDump; + } + + return 
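+ // MLIR_ENABLE_DUMP accepts either a boolean (dump IR around every pass) or
+ // a function name, in which case IR is printed only for the module or
+ // function matching that symbol.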
false; + }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), + printingFlags); + } + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? 
"true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 5a461fb72..fbb951564 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -78,9 +78,17 @@ Type getTypeFromConstraint(char constraint, PatternRewriter &rewriter) { else if (constraint == 'l') ty = IntegerType::get(rewriter.getContext(), 64); else if (constraint == 'f') +#if LLVM_VERSION_MAJOR < 21 ty = FloatType::getF32(rewriter.getContext()); +#else // triton_v3.3.x + ty = Float32Type::get(rewriter.getContext()); +#endif else if (constraint == 'd') +#if LLVM_VERSION_MAJOR < 21 ty = FloatType::getF64(rewriter.getContext()); +#else // triton_v3.3.x + ty = Float64Type::get(rewriter.getContext()); +#endif else { assert(false && "Unsupported constraint"); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index c2940a043..8c731e72b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -212,17 +212,37 @@ TensorCoreType getMmaType(triton::DotOp op) { return TensorCoreType::FP32_FP16_FP16_FP32; if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) return TensorCoreType::FP32_BF16_BF16_FP32; +#if LLVM_VERSION_MAJOR < 21 if (aTy.getElementType().isFloat8E5M2() && bTy.getElementType().isFloat8E5M2()) +#else // triton_v3.3.x + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) +#endif return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; +#if LLVM_VERSION_MAJOR < 21 if (aTy.getElementType().isFloat8E5M2() && bTy.getElementType().isFloat8E4M3FN()) +#else // triton_v3.3.x + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) +#endif return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; +#if LLVM_VERSION_MAJOR < 21 if (aTy.getElementType().isFloat8E4M3FN() && bTy.getElementType().isFloat8E5M2()) +#else // triton_v3.3.x + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) +#endif return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; +#if LLVM_VERSION_MAJOR < 21 if (aTy.getElementType().isFloat8E4M3FN() && bTy.getElementType().isFloat8E4M3FN()) +#else // triton_v3.3.x + if (llvm::isa(aTy.getElementType()) && + llvm::isa(bTy.getElementType())) +#endif return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && op.getInputPrecision() == InputPrecision::TF32) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 1bb55373e..0fb8182d7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -56,9 +56,17 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { return triton::nvgpu::WGMMAEltType::tf32; } else if (aTy.isInteger(8)) { return triton::nvgpu::WGMMAEltType::s8; +#if LLVM_VERSION_MAJOR < 21 } else if (aTy.isFloat8E5M2()) { +#else // triton_v3.3.x + } else if (llvm::isa(aTy)) { +#endif return triton::nvgpu::WGMMAEltType::e5m2; +#if LLVM_VERSION_MAJOR < 21 } else if (aTy.isFloat8E4M3FN()) { +#else // triton_v3.3.x + } else if (llvm::isa(aTy)) { +#endif return 
triton::nvgpu::WGMMAEltType::e4m3; } else { llvm::report_fatal_error("Unsupported mma operand type found"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index ef69b96fc..3f9d7a502 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -445,8 +445,13 @@ struct FpToFpOpConversion llvm::errs() << "\n"; llvm::report_fatal_error("Unsupported rounding mode for conversion."); } +#if LLVM_VERSION_MAJOR < 21 if (computeCapability < 89 && (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { +#else // triton_v3.3.x + if (computeCapability < 89 && (llvm::isa(srcTy) || + llvm::isa(dstTy))) { +#endif llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " "compute capability >= 89" << "\n"; @@ -468,7 +473,11 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); +#if LLVM_VERSION_MAJOR < 21 if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { +#else // triton_v3.3.x + if (llvm::isa(dstElementType)) { +#endif assert(roundingMode.has_value() && "Rounding mode must be specified for convertsions to fp8"); @@ -505,8 +514,13 @@ struct FpToFpOpConversion bool useFP16IntermediateSrc = srcElementType.isF32() && +#if LLVM_VERSION_MAJOR < 21 (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || dstElementType.isFloat8E5M2())) || +#else // triton_v3.3.x + (!(computeCapability >= 90 && + (llvm::isa(dstElementType))) || +#endif roundingMode.value() == RoundingMode::RTZ); bool isDstFP32 = dstElementType.isF32(); Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
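 // i.e. an f32 source is first narrowed to f16 unless this is an SM90+
 // target converting directly to FP8 with a rounding mode other than RTZ;
 // requesting RTZ always forces the f16 intermediate.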