diff --git a/CondaPkg.toml b/CondaPkg.toml index b1db4f8e75..00aa12cb4a 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" +triton = ">= 3.4" diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f8ad3715c0..51e038e816 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -812,6 +812,26 @@ REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops, jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor; } +REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + const char *binary, const char *fnname, int32_t *regs, int32_t *spills, + int32_t *maxThreads) { + CUfunction fun; + CUmodule mod; + + ReactantHandleCuResult(cuModuleLoadData(&mod, binary)); + ReactantHandleCuResult(cuModuleGetFunction(&fun, mod, fnname)); + + ReactantHandleCuResult( + cuFuncGetAttribute(regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + ReactantHandleCuResult( + cuFuncGetAttribute(spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + *spills /= 4; + ReactantHandleCuResult(cuFuncGetAttribute( + maxThreads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); + + return; +} + #else REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; } @@ -827,6 +847,10 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; } REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops, int32_t device_id) {} +REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + const char *binary, const char *fnname, int32_t *regs, int32_t *spills, + int32_t *maxThreads) {} + #endif REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) { diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 4f3ce35cfe..f44e491c67 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -984,6 +984,7 @@ cc_library( "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads", "-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties", + "-Wl,-exported_symbol,_ReactantCudaGetRegsSpillsMaxThreadsFromBinary", "-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId", @@ -1436,6 +1437,24 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "TritonExtJLIncGen", + tbl_outs = [ + ( + [ + "--generator=jl-op-defs", + "--disable-module-wrap=0", + ], + "TritonExt.jl", + ), + ], + tblgen = "//:mlir-jl-tblgen", + td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td", + deps = [ + "@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles", + ], +) + gentbl_cc_library( name = "TPUJLIncGen", tbl_outs = [ diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 7a06caeac6..24f589495d 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,8 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "6137ac98e710adf6f4e953bf441db4e25b2db40f" - +ENZYMEXLA_COMMIT = "4d71da26119a84662cd6f5252a68a35ca1673eae" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f84309fef1..9e4295e9cb 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -42,6 +42,7 @@ for file in [ "MPI.jl", "MemRef.jl", "SparseTensor.jl", + "TritonExt.jl", ] 
build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index dacce466fb..2853e7abd5 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -131,6 +131,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], @@ -221,6 +222,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], diff --git a/docs/src/api/dialects/tritonext.md b/docs/src/api/dialects/tritonext.md new file mode 100644 index 0000000000..a727f0dfbb --- /dev/null +++ b/docs/src/api/dialects/tritonext.md @@ -0,0 +1,11 @@ +```@meta +CollapsedDocStrings = true +``` + +# TritonExt Dialect + +Provides extensions to the Triton dialect. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.triton_ext] +``` diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 1f10630808..af3852ce2e 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -1,14 +1,20 @@ module ReactantPythonCallExt -using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist +using PythonCall: + PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall +using Reactant_jll: Reactant_jll const jaxptr = Ref{Py}() const jnpptr = Ref{Py}() const JAX_TRACING_SUPPORTED = Ref{Bool}(false) +const tritonptr = Ref{Py}() + +const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false) + const tfptr = Ref{Py}() const tf2xlaptr = Ref{Py}() const npptr = Ref{Py}() @@ -33,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict( ComplexF64 => :complex64, ) +const MLIR_TYPE_STRING = Dict( + Float64 => "fp64", + Float32 => "fp32", + Float16 => "fp16", + Int64 => "i64", + Int32 => "i32", + Int16 => "i16", + Int8 => "i8", + UInt64 => "ui64", + UInt32 => "ui32", + UInt16 => "ui16", + UInt8 => "ui8", + Bool => "i1", + Reactant.F8E4M3FN => "fp8e4nv", + Reactant.F8E5M2FNUZ => "fp8e5b16", + Reactant.F8E4M3FNUZ => "fp8e4b8", + Reactant.F8E5M2 => "fp8e5", +) +if isdefined(Core, :BFloat16) + MLIR_TYPE_STRING[Core.BFloat16] = "bf16" +end + function __init__() try jaxptr[] = pyimport("jax") @@ -43,6 +71,14 @@ function __init__() be supported." exception = (err, catch_backtrace()) end + try + tritonptr[] = pyimport("triton") + TRITON_COMPILE_SUPPORTED[] = true + catch err + @warn "Failed to import triton. Compiling jax functions with triton won't be \ + supported." 
exception = (err, catch_backtrace()) + end + try tfptr[] = pyimport("tensorflow") tfptr[].config.set_visible_devices(pylist(); device_type="GPU") diff --git a/ext/ReactantPythonCallExt/overlays.jl b/ext/ReactantPythonCallExt/overlays.jl index 20ffa7384f..ca5bcfcea5 100644 --- a/ext/ReactantPythonCallExt/overlays.jl +++ b/ext/ReactantPythonCallExt/overlays.jl @@ -1,7 +1,7 @@ -@reactant_overlay function PythonCall.pycall(f::Py, args...) +@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...) if Reactant.looped_any(Reactant.use_overlayed_version, args) - return pycall_with_jax_tracing(f, args...) + return overlayed_pycall(f, args...; kwargs...) else - return Base.inferencebarrier(PythonCall.pycall)(f, args...) + return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...) end end diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 23674d9155..f328f6d4ac 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe ) end -function pycall_with_jax_tracing(f::Py, args...) +function overlayed_pycall(f::Py, args...; kwargs...) + @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[] + # TODO: check for Autotuner and Heutistics as well + if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction) + return overlayed_pycall_with_triton(f, args...; kwargs...) + else + @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions." + return overlayed_pycall_with_jax_tracing(f, args...) + end +end + +function overlayed_pycall_with_jax_tracing(f::Py, args...) JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.") seen_args = Reactant.OrderedIdDict() @@ -35,3 +46,139 @@ function pycall_with_jax_tracing(f::Py, args...) res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end + +struct TritonMetadata{CK,MD,DP} + compiled_kernel::CK + metadata::MD + device_properties::DP + num_warps::Int + num_stages::Int + num_ctas::Int + num_regs::Int + num_spills::Int + max_num_threads::Int +end + +canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata) +canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata) +function canonicalize_grid(grid::Dims{N}, metadata) where {N} + @assert N <= 3 + @assert all(grid .> 0) + return (grid..., ntuple(_ -> 1, 3 - N)...) +end + +signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing +signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing +signature_string(x::T) where {T<:Number} = string(x), x +signature_string(x) = error("Unsupported argument type: $(typeof(x))") + +# TODO: better name for hints? +function overlayed_pycall_with_triton( + kernel::Py, + args...; + grid, + num_warps::Integer=4, + num_stages::Integer=3, + num_ctas::Integer=1, + hints=nothing, +) + @assert num_ctas == 1 "TODO: num_ctas > 1 not supported" + triton = tritonptr[] + + mapped = map(signature_string, args) + signature = first.(mapped) + # TODO: are hints actually correctly set? + hints = + hints === nothing ? 
Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints) + constants = Dict( + kernel.arg_names[i - 1] => constant for + (i, constant) in enumerate(last.(mapped)) if constant !== nothing + ) + for (k, v) in hints + v == 1 && (constants[kernel.arg_names[k - 1]] = v) + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature)) + for k in keys(constants) + sigmap[k] = "constexpr" + end + + for h in values(hints) + @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h" + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + src = triton.compiler.ASTSource(; + fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs + ) + + # TODO: pass the device/client here from `compile` + # TODO: cluster dims + client = Reactant.XLA.default_backend() + @assert Reactant.XLA.platform_name(client) == "cuda" + device = Reactant.XLA.default_device(client) + device_properties = Reactant.XLA.device_properties(device) + + target = triton.backends.compiler.GPUTarget( + Reactant.XLA.platform_name(client), + parse(Int, "$(device_properties.major)$(device_properties.minor)"), + device_properties.warp_size, + ) + backend = triton.compiler.make_backend(target) + options = backend.parse_options( + pydict( + "num_warps" => num_warps, + "num_stages" => num_stages, + "num_ctas" => num_ctas, + "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)), + ), + ) + + # Currently we are doing a double compilation here. can we do better? + # we are compiling here + lowering again inside enzymejax + compiled_kernel = triton.compile(src; target=target, options=options.__dict__) + + cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"]) + fname = pyconvert(String, compiled_kernel.metadata.name) + n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}() + GC.@preserve cubin fname n_regs n_spills n_max_threads begin + @ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + cubin::Ptr{Cvoid}, + fname::Cstring, + n_regs::Ptr{Int32}, + n_spills::Ptr{Int32}, + n_max_threads::Ptr{Int32}, + )::Cvoid + end + + metadata = TritonMetadata( + compiled_kernel, + compiled_kernel.metadata, + device_properties, + num_warps, + num_stages, + num_ctas, + Int(n_regs[]), + Int(n_spills[]), + Int(n_max_threads[]), + ) + + grid = canonicalize_grid(grid, metadata) + + return @opcall triton_call( + pyconvert(String, compiled_kernel.asm["source"]), + filter(x -> x isa Reactant.TracedType, args)...; + func_name=fname, + grid_x=@opcall(constant(grid[1])), + grid_y=@opcall(constant(grid[2])), + grid_z=@opcall(constant(grid[3])), + block_x=@opcall(constant(num_warps * device_properties.warp_size)), + block_y=@opcall(constant(1)), + block_z=@opcall(constant(1)), + num_ctas, + num_warps, + threads_per_warp=device_properties.warp_size, + enable_source_remat=false, + ) +end diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..e545aa6550 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,8 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :no_triton, + :before_triton_lowering, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index a6be948b1c..6c27898574 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -702,6 +702,8 @@ function optimization_passes( lower_comms::Bool=true, max_constant_threshold::Int=1024, backend::String="gpu", + enable_triton_passes::Bool=false, + 
device_properties::Union{Nothing,XLA.DeviceProperties}=nothing, ) transform_passes_list = [ "patterns=compare_op_canon<16>", @@ -1300,9 +1302,103 @@ function optimization_passes( push!(passes, "remove-duplicate-func-def") end push!(passes, func_passes) + if enable_triton_passes && backend == "cuda" + push!(passes, triton_optimization_passes(device_properties)) + end return join(passes, ',') end +# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc +# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes +function triton_optimization_passes(device_properties) + @assert device_properties !== nothing "Device properties must be provided to run \ + triton passes. This might happen if you are \ + compiling a triton kernel for non-cuda backend." + major_version = device_properties.major + minor_version = device_properties.minor + + all_passes = join( + [ + "canonicalize", + "triton-rewrite-tensor-pointer", + "canonicalize", + "triton-combine", + "triton-reorder-broadcast", + "cse", + "symbol-dce", + "triton-loop-unroll", + "convert-triton-to-triton-gpu-preserving-module-attributes{target=cuda:$(major_version)$(minor_version)}", + "tritongpu-coalesce", + "tritongpu-F32DotTC", + "triton-nvidia-gpu-plan-cta", + "tritongpu-remove-layout-conversions", + "tritongpu-optimize-thread-locality", + "tritongpu-accelerate-matmul", + "tritongpu-remove-layout-conversions", + "tritongpu-optimize-dot-operands", + "canonicalize", + "triton-nvidia-optimize-descriptor-encoding", + "triton-loop-aware-cse", + "tritongpu-fuse-nested-loops", + "canonicalize", + "triton-licm", + "tritongpu-optimize-accumulator-init", + "tritongpu-hoist-tmem-alloc", + "tritongpu-promote-lhs-to-tmem", + "tritongpu-assign-latencies", + "tritongpu-schedule-loops", + "tritongpu-automatic-warp-specialization", + "tritongpu-partition-scheduling", + "tritongpu-load-mma-specialization", + "tritongpu-rewrite-partition-dependencies", + "sccp", + "cse", + "tritongpu-partition-loops", + "tritongpu-optimize-partition-warps", + "tritongpu-schedule-loops", + "tritongpu-pipeline", + "tritongpu-combine-tensor-select-and-if", + "triton-nvidia-gpu-remove-tmem-tokens", + "canonicalize", + "triton-loop-aware-cse", + "tritongpu-prefetch", + "tritongpu-optimize-dot-operands", + "canonicalize", + "tritongpu-coalesce-async-copy", + "triton-nvidia-optimize-tmem-layouts", + "tritongpu-remove-layout-conversions", + "triton-nvidia-interleave-tmem", + "tritongpu-reduce-data-duplication", + "tritongpu-reorder-instructions", + "triton-loop-aware-cse", + "symbol-dce", + "triton-nvidia-tma-lowering", + "triton-nvidia-gpu-fence-insertion", + "sccp", + "canonicalize", + "triton-nvidia-mma-lowering", + "tritongpu-combine-tensor-select-and-if", + "tritongpu-allocate-warp-groups", + "convert-scf-to-cf", + "allocate-shared-memory", + "triton-tensor-memory-allocation", + "tritongpu-global-scratch-memory-allocation", + "convert-triton-gpu-to-llvm", + "canonicalize", + "cse", + "convert-nv-gpu-to-llvm", + "convert-warp-specialize-to-llvm", + "reconcile-unrealized-casts", + "canonicalize", + "cse", + "symbol-dce", + "enable-line-info", + ], + ",", + ) + return "triton_ext.module(builtin.module($(all_passes)))" +end + # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. 
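+# As a rough sketch, for a device with compute capability 9.0 the string returned by
+# `triton_optimization_passes` nests the Triton pipeline under
+# `triton_ext.module(builtin.module(...))`, so these passes run only on the Triton
+# payload modules emitted by `Ops.triton_call`, e.g.
+#   triton_ext.module(builtin.module(canonicalize,triton-rewrite-tensor-pointer,...,convert-triton-to-triton-gpu-preserving-module-attributes{target=cuda:90},...,enable-line-info))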
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" @@ -1395,7 +1491,8 @@ function __get_compile_options_and_kwargs(; end function compile_mlir(f, args; client=nothing, kwargs...) - backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + client = client !== nothing ? client : XLA.default_backend() + backend = XLA.platform_name(client) if backend == "CUDA" backend = "GPU" @@ -1414,6 +1511,7 @@ function compile_mlir(f, args; client=nothing, kwargs...) compile_options; backend, runtime=XLA.runtime(client), + client, kwargs..., ) @@ -1430,11 +1528,9 @@ end const PartitionKA = Ref{Bool}(true) -const cubinChip = Ref{String}("sm_60") -const cubinFormat = Ref{String}("bin") const cuindexBitWidth = Ref{Int}(32) +const cubinFormat = Ref{String}("bin") const cuOptLevel = Ref{Int}(2) -const cuWarpSize = Ref{Int}(32) # Wgatever the relevant highest version from our LLVM is within NVPTX.td # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684 @@ -1580,8 +1676,11 @@ function compile_mlir!( backend="gpu", runtime::Union{Val{:PJRT},Val{:IFRT}}, legalize_stablehlo_to_mhlo::Bool=false, + client=nothing, kwargs..., ) + client = client !== nothing ? client : XLA.default_backend() + # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues @@ -1648,6 +1747,9 @@ function compile_mlir!( toolkit = XLA.CUDA_DATA_DIR[] + default_device = XLA.default_device(client) + device_properties = XLA.device_properties(default_device) + if backend == "cpu" || backend == "tpu" kern = "lower-kernel{backend=cpu},canonicalize" if backend == "tpu" @@ -1655,25 +1757,25 @@ function compile_mlir!( else jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce" end - elseif DEBUG_KERNEL[] - curesulthandler = dlsym( - Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" - ) - @assert curesulthandler !== nothing - curesulthandler = Base.reinterpret(UInt, curesulthandler) + else kern = if is_raising "lower-kernel{backend=cpu},symbol-dce,canonicalize" else "lower-kernel,canonicalize" end - jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" - else - kern = if is_raising - "lower-kernel{backend=cpu},symbol-dce,canonicalize" + + cubinChip = "sm_$(device_properties.major)$(device_properties.minor)" + if DEBUG_KERNEL[] + curesulthandler = dlsym( + Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" + ) + @assert curesulthandler !== nothing + curesulthandler = Base.reinterpret(UInt, curesulthandler) + extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler " else - "lower-kernel,canonicalize" + extra_lowerjit_options = "" end - jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" + jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" end recognize_comms = true @@ -1687,10 +1789,31 @@ function compile_mlir!( end 
opt_passes = optimization_passes( - compile_options; sroa=true, recognize_comms, lower_comms, backend + compile_options; + sroa=true, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=false, + device_properties, ) opt_passes2 = optimization_passes( - compile_options; sroa=false, recognize_comms, lower_comms, backend + compile_options; + sroa=false, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=false, + device_properties, + ) + opt_passes_with_triton = optimization_passes( + compile_options; + sroa=false, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=true, + device_properties, ) raise_passes = if raise isa String @@ -1705,15 +1828,16 @@ function compile_mlir!( opt_passes2 if DUS_TO_CONCAT[] - opt_passes3 = optimization_passes( + opt_passes_dus_to_concat = optimization_passes( compile_options; sroa=false, dus_to_concat=true, recognize_comms, lower_comms, backend, + device_properties, ) - result = result * "," * opt_passes3 + result = result * "," * opt_passes_dus_to_concat end result else @@ -1739,6 +1863,8 @@ function compile_mlir!( [ "mark-func-memory-effects", opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -1760,12 +1886,13 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, @@ -1776,7 +1903,7 @@ function compile_mlir!( ), "all", ) - elseif compile_options.optimization_passes === :before_kernel + elseif compile_options.optimization_passes === :no_triton run_pass_pipeline!( mod, join( @@ -1799,6 +1926,57 @@ function compile_mlir!( end, ',', ), + "no_triton", + ) + elseif compile_options.optimization_passes === :before_triton_lowering + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + ["mark-func-memory-effects", opt_passes] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + opt_passes_with_triton, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + legalize_chlo_to_stablehlo..., + opt_passes2, + ] + end, + ',', + ), + "before_triton_lowering", + ) + elseif compile_options.optimization_passes === :before_kernel + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + ["mark-func-memory-effects", opt_passes] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + opt_passes_with_triton, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + legalize_chlo_to_stablehlo..., + opt_passes2, + "lower-triton", + ] + end, + ',', + ), "before_kernel", ) elseif compile_options.optimization_passes === :before_jit @@ -1808,7 +1986,8 @@ function compile_mlir!( if compile_options.raise_first [ "mark-func-memory-effects", - opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -1828,12 +2007,13 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, raise_passes, ] @@ -1855,12 +2035,13 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", 
"remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, ] end, @@ -1878,7 +2059,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", @@ -1920,7 +2101,7 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes_with_triton, lower_enzymexla_linalg_pass, jit, ] @@ -1933,7 +2114,8 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, @@ -1951,7 +2133,8 @@ function compile_mlir!( if compile_options.raise_first [ "mark-func-memory-effects", - opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -1966,9 +2149,10 @@ function compile_mlir!( "mark-func-memory-effects", opt_passes, "enzyme-batch", - opt_passes2, + opt_passes_with_triton, enzyme_pass, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math", + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, @@ -2002,6 +2186,7 @@ function compile_mlir!( recognize_comms, lower_comms, backend, + device_properties, ), "post_op_transpose_reshape", ) @@ -3477,7 +3662,8 @@ function compile_xla( context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + client = client !== nothing ? client : XLA.default_backend() + backend = XLA.platform_name(client) if backend == "CUDA" backend = "GPU" @@ -3498,6 +3684,7 @@ function compile_xla( compile_options; backend, runtime=XLA.runtime(client), + client, kwargs..., ) diff --git a/src/Ops.jl b/src/Ops.jl index 252d87dc84..79971abf3e 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3,7 +3,7 @@ # Julia and Reactant semantics should be considered on the higher abstractions that use these module Ops using ..MLIR: MLIR -using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla +using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext using ..Reactant: Reactant, TracedRArray, @@ -1743,8 +1743,175 @@ end end # Generate a unique name given a module hash and a function name. 
-function _hlo_call_name(orig_name, module_suffix) - return orig_name * "_hlo_call_" * module_suffix +_new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix + +function _extract_function( + code::String; + func_name::String="main", + func_op_kind::String="func.func", + location::MLIR.IR.Location=MLIR.IR.Location(), +) + module_suffix = string(hash(code); base=16) + name_to_call = func_name * "_call_" * module_suffix + mod_name = func_name * "_module_" * module_suffix + symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) + + use_ttext_module = split(func_op_kind, ".")[1] == "tt" + + if use_ttext_module + tt_mod_name = func_name * "_tt_module_" * module_suffix + tt_region = MLIR.IR.Region() + tt_block = MLIR.IR.Block() + push!(tt_region, tt_block) + triton_mod_op = triton_ext.module_(; + location, bodyRegion=tt_region, sym_name=tt_mod_name + ) + MLIR.IR.rmfromparent!(triton_mod_op) + push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module + + region = MLIR.IR.Region() + push!(region, MLIR.IR.Block()) + moduleop = MLIR.Dialects.builtin.module_(; + location, bodyRegion=region, sym_name=mod_name + ) + MLIR.IR.rmfromparent!(moduleop) + push!(tt_block, moduleop) # insert into triton module + + top_level_block = MLIR.IR.Block( + MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false + ) + fn = nothing + + symref = MLIR.IR.SymbolRefAttribute( + tt_mod_name, + MLIR.IR.Attribute[ + MLIR.IR.FlatSymbolRefAttribute(mod_name), + MLIR.IR.FlatSymbolRefAttribute(name_to_call), + ], + ) + else + current_module = MLIR.IR.mmodule() + moduleop = MLIR.IR.Operation(current_module) + top_level_block = MLIR.IR.body(current_module) + fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call) + symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call) + end + + if isnothing(fn) + new_mod = parse(MLIR.IR.Module, code) + new_mod_op = MLIR.IR.Operation(new_mod) + body = MLIR.IR.body(new_mod) + + operations = collect(MLIR.IR.OperationIterator(body)) + idx = Base.findfirst(op -> MLIR.IR.name(op) == func_op_kind, operations) + @assert idx !== nothing + op = operations[idx] + + fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) + fn_name == func_name && (fn = op) + + res = MLIR.IR.LogicalResult( + MLIR.API.mlirSymbolTableReplaceAllSymbolUses(fn_name, name_to_call, new_mod_op) + ) + @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" + + if !use_ttext_module + # Set function private + MLIR.IR.attr!( + op, + MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), + MLIR.IR.Attribute("private"), + ) + end + + # Change function name + MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) + + for op in operations + MLIR.IR.rmfromparent!(op) + push!(top_level_block, op) + end + end + + if isnothing(fn) + error("hlo_call: could not find function $func_name in the provided module") + end + + return fn, symref, moduleop +end + +function triton_call( + mlir_code::String, + args::Union{TracedRArray,TracedRNumber,Number}...; + func_name::String="main", + grid_x::TracedRNumber{<:Integer}, + grid_y::TracedRNumber{<:Integer}, + grid_z::TracedRNumber{<:Integer}, + block_x::TracedRNumber{<:Integer}, + block_y::TracedRNumber{<:Integer}, + block_z::TracedRNumber{<:Integer}, + num_ctas::Integer=1, + num_warps::Integer=4, + threads_per_warp::Integer=32, + enable_source_remat::Bool=false, + location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), +) + _, symref, modop = _extract_function( + mlir_code; 
func_name, func_op_kind="tt.func", location + ) + + MLIR.IR.attr!(modop, "enzymexla.ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps))) + MLIR.IR.attr!(modop, "enzymexla.ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas))) + MLIR.IR.attr!( + modop, "enzymexla.ttg.threads-per-warp", MLIR.IR.Attribute(Int32(threads_per_warp)) + ) + if enable_source_remat + MLIR.IR.attr!(modop, "enzymexla.ttg.enable-source-remat", MLIR.IR.UnitAttribute()) + end + + result_types = MLIR.IR.Type[] + output_operand_aliases = MLIR.IR.Attribute[] + output_to_arg = Int[] + for (i, arg) in enumerate(args) + if arg isa TracedRArray + push!(result_types, mlir_type(typeof(arg), size(arg))) + push!( + output_operand_aliases, + MLIR.IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL + ), + ), + ) + push!(output_to_arg, i) + end + end + + results = triton_ext.call( + grid_x.mlir_data, + grid_y.mlir_data, + grid_z.mlir_data, + block_x.mlir_data, + block_y.mlir_data, + block_z.mlir_data, + [Reactant.TracedUtils.get_mlir_data(a) for a in args]; + fn=symref, + result_0=result_types, + location, + output_operand_aliases, + ) + + array_results = () + for i in 1:MLIR.IR.nresults(results) + arg = args[output_to_arg[i]] + res = Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}( + (), MLIR.IR.result(results, i), size(arg) + ) + copyto!(arg, res) + array_results = (array_results..., res) + end + length(array_results) == 1 && return array_results[1] + return array_results end """ @@ -1773,69 +1940,16 @@ julia> Reactant.@jit( """ @noinline function hlo_call( code, - args...; + args::Union{TracedRArray,TracedRNumber}...; func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - module_suffix = string(hash(code); base=16) - name_to_call = _hlo_call_name(func_name, module_suffix) - - current_module = MLIR.IR.mmodule() - top_level_block = MLIR.IR.body(current_module) - - symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - - fn = MLIR.IR.lookup( - MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call - ) - if isnothing(fn) - new_mod = parse(MLIR.IR.Module, code) - new_mod_op = MLIR.IR.Operation(new_mod) - body = MLIR.IR.body(new_mod) - - operations = collect(MLIR.IR.OperationIterator(body)) - for op in operations - if MLIR.IR.name(op) == "func.func" - fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) - if fn_name == func_name - fn = op - end - - new_name = _hlo_call_name(fn_name, module_suffix) - res = MLIR.IR.LogicalResult( - MLIR.API.mlirSymbolTableReplaceAllSymbolUses( - fn_name, new_name, new_mod_op - ), - ) - @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" - - # Set function private - MLIR.IR.attr!( - op, - MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), - MLIR.IR.Attribute("private"), - ) - - # Change function name - MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name)) - end - end - - for op in operations - MLIR.IR.rmfromparent!(op) - push!(top_level_block, op) - end - end - - if isnothing(fn) - error("hlo_call: could not find function $func_name in the provided module") - end + fn, symref, _ = _extract_function(code; func_name, func_op_kind="func.func", location) ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) - @assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers" - @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number 
of arguments for function $func_name" + @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name. Expected $(MLIR.IR.ninputs(ftype)), got $(length(args))" for (i, arg) in enumerate(args) expected_type = MLIR.IR.input(ftype, i) @@ -1847,7 +1961,7 @@ julia> Reactant.@jit( call = MLIR.Dialects.func.call( operands; result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], - callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + callee=symref, location, ) diff --git a/src/Reactant.jl b/src/Reactant.jl index 7c31f1a8c5..3922f0761f 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -236,6 +236,35 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +""" + rowmajor_strides(x::AbstractArray) + +Returns the strides of the array `x` assuming that the array is stored in row-major order. +""" +rowmajor_strides(x::AbstractArray) = rowmajor_strides(size(x)) +function rowmajor_strides(sz::NTuple{N,Int}) where {N} + strides = ntuple(_ -> 1, N) + for i in (N - 1):-1:1 + strides = Base.setindex(strides, strides[i + 1] * sz[i + 1], i) + end + return strides +end + +""" + rowmajor_stride(x::AbstractArray, i::Integer) + +Returns the stride of the array `x` at dimension `i` assuming that the array is stored in +row-major order. +""" +rowmajor_stride(x::AbstractArray, i::Integer) = rowmajor_stride(size(x), i) +function rowmajor_stride(sz::NTuple{N,Int}, i::Integer) where {N} + s = 1 + for j in (i + 1):N + s *= sz[j] + end + return s +end + export StackedBatchDuplicated, StackedBatchDuplicatedNoNeed const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/src/mlir/Dialects/TritonExt.jl b/src/mlir/Dialects/TritonExt.jl new file mode 100644 index 0000000000..bb79bade44 --- /dev/null +++ b/src/mlir/Dialects/TritonExt.jl @@ -0,0 +1,84 @@ +module triton_ext +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function call( + gridx::Value, + gridy::Value, + gridz::Value, + blockx::Value, + blocky::Value, + blockz::Value, + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...] 
+ owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "triton_ext.call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function module_(; sym_name, bodyRegion::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[bodyRegion,] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name),] + + return create_operation( + "triton_ext.module", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # triton_ext diff --git a/src/mlir/IR/Module.jl b/src/mlir/IR/Module.jl index 12794b30ba..c7f938d5b8 100644 --- a/src/mlir/IR/Module.jl +++ b/src/mlir/IR/Module.jl @@ -52,7 +52,8 @@ body(module_) = Block(API.mlirModuleGetBody(module_), false) Views the module as a generic operation. """ -Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false) +Operation(module_::Module, owned::Bool=false) = + Operation(API.mlirModuleGetOperation(module_), owned) function Base.show(io::IO, module_::Module) return show(io, Operation(module_)) diff --git a/src/xla/Device.jl b/src/xla/Device.jl index 19e9ef737f..fd76bb6e3e 100644 --- a/src/xla/Device.jl +++ b/src/xla/Device.jl @@ -11,6 +11,7 @@ function device_kind end function default_memory end function memories end function is_addressable end +function get_local_hardware_id end """ device_ordinal(device::Device) @@ -29,3 +30,74 @@ end function is_addressable(device::AbstractDevice) return device ∈ addressable_devices(client(device)) end + +# Keep in sync with API.cpp +struct DeviceProperties + total_global_mem::Csize_t + shared_mem_per_block::Csize_t + regs_per_block::Cint + warp_size::Cint + max_threads_per_block::Cint + max_threads_dim::NTuple{3,Cint} + max_grid_size::NTuple{3,Cint} + total_const_mem::Csize_t + major::Cint + minor::Cint + multi_processor_count::Cint + can_map_host_memory::Cint + l2_cache_size::Cint + max_threads_per_multiprocessor::Cint +end + +const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}() + +""" + device_properties(device::AbstractDevice) + +Get a struct containing device properties. Which exact fields are populated relies on the +underlying device implementation. 
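+
+For example (a minimal sketch; which fields carry meaningful values depends on the
+actual device), the compute capability reported for a CUDA device can be turned into
+the `sm_XY` chip string used during compilation:
+
+    props = Reactant.XLA.device_properties(Reactant.XLA.default_device(Reactant.XLA.default_backend()))
+    chip = "sm_\$(props.major)\$(props.minor)"  # e.g. "sm_80" on an Ampere A100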
+""" +function device_properties(device::AbstractDevice) + pname = platform_name(client(device)) + local_hardware_id = get_local_hardware_id(device) + + if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname)) + return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] + end + + jldevprops = Ref{DeviceProperties}() + if pname == "cuda" + GC.@preserve jldevprops begin + @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties( + jldevprops::Ptr{Cvoid}, local_hardware_id::Cint + )::Cvoid + end + else + @warn "`get_properties` not implemented for platform: $(pname)" maxlog = 1 + end + DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[] + return jldevprops[] +end + +function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties) + return print( + io, + """ + DeviceProperties + ---------------- + Total Global Mem: $(_format_bytes(props.total_global_mem)) + Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block)) + Regs Per Block: $(props.regs_per_block) + Warp Size: $(props.warp_size) + Max Threads Per Block: $(props.max_threads_per_block) + Max Threads Dim: $(props.max_threads_dim) + Max Grid Size: $(props.max_grid_size) + Total Const Mem: $(_format_bytes(props.total_const_mem)) + Version: $(VersionNumber(props.major, props.minor)) + Multi Processor Count: $(props.multi_processor_count) + Can Map Host Memory: $(props.can_map_host_memory) + L2 Cache Size: $(props.l2_cache_size) + Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor) + """, + ) +end diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl index 7d269e166c..672900454a 100644 --- a/src/xla/IFRT/Device.jl +++ b/src/xla/IFRT/Device.jl @@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device) return error("Not implemented for ifrt devices") end +function XLA.get_local_hardware_id(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId( + device.device::Ptr{Cvoid} + )::Cint + end +end + function XLA.default_memory(device::Device) GC.@preserve device begin return Memory( diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl index 2a29c6279b..4a4dd178e7 100644 --- a/src/xla/PJRT/Device.jl +++ b/src/xla/PJRT/Device.jl @@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device) end end +function XLA.get_local_hardware_id(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId( + device.device::Ptr{Cvoid} + )::Cint + end +end + function XLA.is_addressable(device::Device) GC.@preserve device begin return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable( diff --git a/src/xla/Stats.jl b/src/xla/Stats.jl index bc66cc348a..59f62609c2 100644 --- a/src/xla/Stats.jl +++ b/src/xla/Stats.jl @@ -13,7 +13,7 @@ struct JLAllocatorStats peak_pool_bytes::Int64 end -_format_bytes(x) = Base.format_bytes(x) +_format_bytes(x) = x < 0 ? nothing : Base.format_bytes(x) _format_bytes(x::Nothing) = x """ diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 1a7ffc17f2..f14139b890 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT) ) state.clients["cuda"] = gpu state.default_client = gpu - - # set values for cuda. This is being done here since we need cuda - # to be initialized before we can use it. initializing the devices - # implicitly initializes cuda. 
- cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32 - cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32 - Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)" - - Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32 catch e println(stdout, e) end diff --git a/test/integration/triton/layer_norm.jl b/test/integration/triton/layer_norm.jl new file mode 100644 index 0000000000..f9652da235 --- /dev/null +++ b/test/integration/triton/layer_norm.jl @@ -0,0 +1,71 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +layer_norm_kernel = pyimport("layer_norm").layer_norm_fwd_fused +layer_norm_kernel_v2 = pyimport("layer_norm").layer_norm_fwd_fused_simple + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function layer_norm_triton( + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T}, simple::Bool +) where {T} + x_transposed = permutedims(x, (2, 1)) # match python array layout + y = similar(x_transposed) + M, N = size(x_transposed) + mean = similar(x_transposed, Float32, M) + rstd = similar(x_transposed, Float32, M) + + max_fused_size = 65536 ÷ sizeof(T) + block_size = min(max_fused_size, nextpow(2, N)) + + if N > block_size + throw(ArgumentError("This layer norm doesn't support feature dim >= 64KB.")) + end + + (simple ? layer_norm_kernel_v2 : layer_norm_kernel)( + x_transposed, + y, + weight, + bias, + mean, + rstd, + Reactant.rowmajor_stride(x_transposed, 1), + N, + 1.0f-5, + block_size; + num_warps=min(max(block_size ÷ 256, 1), 8), + num_ctas=1, + grid=(M,), + ) + + return permutedims(y, (2, 1)), mean, rstd +end + +function layer_norm_naive( + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T} +) where {T} + mean = sum(x; dims=1) ./ size(x, 1) + rstd = 1 ./ sqrt.(sum(abs2, x .- mean; dims=1) ./ size(x, 1) .+ 1e-5) + x_hat = (x .- mean) .* rstd + return x_hat .* weight .+ bias, vec(mean), vec(rstd) +end + +@testset "fused_layer_norm" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 257, 2056)) + weight_ra = Reactant.to_rarray(rand(Float32, 257)) + bias_ra = Reactant.to_rarray(rand(Float32, 257)) + + y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, false) + y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra) + y_ra3, mean_ra3, rstd_ra3 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, true) + + @test y_ra1 ≈ y_ra2 + @test y_ra2 ≈ y_ra3 + @test mean_ra1 ≈ mean_ra2 + @test mean_ra2 ≈ mean_ra3 + @test rstd_ra1 ≈ rstd_ra2 + @test rstd_ra2 ≈ rstd_ra3 + end +end diff --git a/test/integration/triton/layer_norm.py b/test/integration/triton/layer_norm.py new file mode 100644 index 0000000000..9595491551 --- /dev/null +++ b/test/integration/triton/layer_norm.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl + + +@triton.jit +def layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@triton.jit +def layer_norm_fwd_fused_simple( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + + # Compute mean - process one element at a time + mean = 0.0 + for i in range(N): + x = tl.load(X + i).to(tl.float32) + mean += x + mean = mean / N + + # Compute variance - process one element at a time + var = 0.0 + for i in range(N): + x = tl.load(X + i).to(tl.float32) + diff = x - mean + var += diff * diff + var = var / N + rstd = 1.0 / tl.sqrt(var + eps) + + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) diff --git a/test/integration/triton/libdevice.jl b/test/integration/triton/libdevice.jl new file mode 100644 index 0000000000..89eee78e99 --- /dev/null +++ b/test/integration/triton/libdevice.jl @@ -0,0 +1,21 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +asin_kernel = pyimport("libdevice").asin_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function asin_triton(x::AbstractVector{T}) where {T} + out = similar(x) + asin_kernel(x, out, length(x), 1024; grid=(cld(length(x), 1024),)) + return out +end + +@testset "libdevice asin" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2096)) + + @test @jit(asin_triton(x_ra)) ≈ @jit(asin.(x_ra)) + end +end diff --git a/test/integration/triton/libdevice.py b/test/integration/triton/libdevice.py new file mode 100644 index 0000000000..ac9a199952 --- /dev/null +++ b/test/integration/triton/libdevice.py @@ -0,0 +1,19 @@ +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +@triton.jit +def asin_kernel( + 
x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) diff --git a/test/integration/triton/low_memory_dropout.jl b/test/integration/triton/low_memory_dropout.jl new file mode 100644 index 0000000000..48be41490b --- /dev/null +++ b/test/integration/triton/low_memory_dropout.jl @@ -0,0 +1,30 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +low_memory_dropout_kernel = pyimport("low_memory_dropout").seeded_dropout_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function seeded_dropout(x::AbstractVector{T}, p::Number, seed) where {T} + output = similar(x) + mask = similar(x, Bool) + low_memory_dropout_kernel( + x, output, mask, length(x), p, seed, 1024; grid=(cld(length(x), 1024),) + ) + return output, mask +end + +function apply_dropout(x::AbstractVector{T}, mask::AbstractVector, p::Number) where {T} + return x .* mask ./ (1 - p) +end + +@testset "low_memory_dropout" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2056)) + + out, mask = @jit seeded_dropout(x_ra, 0.25f0, ConcreteRNumber(123)) + + @test @jit(apply_dropout(x_ra, mask, 0.25f0)) ≈ out + end +end diff --git a/test/integration/triton/low_memory_dropout.py b/test/integration/triton/low_memory_dropout.py new file mode 100644 index 0000000000..ad32ac0014 --- /dev/null +++ b/test/integration/triton/low_memory_dropout.py @@ -0,0 +1,29 @@ +import triton +import triton.language as tl + + +@triton.jit +def seeded_dropout_kernel( + x_ptr, + output_ptr, + mask_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + mask_out = tl.where(x_keep, 1.0, 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + tl.store(mask_ptr + offsets, mask_out, mask=mask) diff --git a/test/integration/triton/matmul.jl b/test/integration/triton/matmul.jl new file mode 100644 index 0000000000..ea841dd771 --- /dev/null +++ b/test/integration/triton/matmul.jl @@ -0,0 +1,61 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +matmul_kernel = pyimport("matmul").matmul_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function matmul_triton(a::AbstractMatrix{T}, b::AbstractMatrix{T}) where {T} + # a: [M, K] --> aᵀ: [K, M] + # b: [K, N] --> bᵀ: [N, K] + # c: a × b [M, N] --> cᵀ: bᵀ × aᵀ [N, M] + a_transposed = permutedims(a, (2, 1)) # match python array layout + b_transposed = permutedims(b, (2, 1)) # match python array layout + @assert size(b_transposed, 2) == size(a_transposed, 1) "Inner dimensions must match \ + for matmul" + M, K = size(b_transposed) + K, N = size(a_transposed) + + out = similar(a_transposed, T, M, N) # cᵀ + + matmul_kernel( + b_transposed, + a_transposed, + out, + M, + N, + K, + Reactant.rowmajor_stride(b_transposed, 1), + Reactant.rowmajor_stride(b_transposed, 2), + Reactant.rowmajor_stride(a_transposed, 1), + 
Reactant.rowmajor_stride(a_transposed, 2), + Reactant.rowmajor_stride(out, 1), + Reactant.rowmajor_stride(out, 2), + 64, + 256, + 32, + 8; + grid=(cld(M, 64) * cld(N, 256),), + num_stages=4, + num_warps=4, + ) + + return permutedims(out, (2, 1)) +end + +@testset "matmul" begin + if RunningOnCUDA + @testset for M in (4, 32, 256, 1024), + K in (4, 32, 512, 2048), + N in (4, 32, 256, 1024) + + a = Reactant.to_rarray(rand(Float32, M, K)) + b = Reactant.to_rarray(rand(Float32, K, N)) + + # XXX: shared_memory???? + # XXX: seems to work correctly for small matrices + @test_broken @jit(matmul_triton(a, b)) ≈ @jit(a * b) + end + end +end diff --git a/test/integration/triton/matmul.py b/test/integration/triton/matmul.py new file mode 100644 index 0000000000..f4dafc0318 --- /dev/null +++ b/test/integration/triton/matmul.py @@ -0,0 +1,264 @@ +import triton +import triton.language as tl + + +# XXX: enable and support autotuning +# @triton.autotune( +# configs=[ +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 32, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=5, +# num_warps=2, +# ), +# # Good config for fp8 inputs. 
+# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 256, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 256, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# ], +# key=["M", "N", "K"], +# ) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Add some integer bound assumptions. + # This helps to guide integer analysis in the backend to optimize + # load/store offset address calculation + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        # We accumulate along the K dimension.
+        accumulator = tl.dot(a, b, accumulator)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+    c = accumulator.to(tl.float16)
+
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
diff --git a/test/integration/triton/softmax.jl b/test/integration/triton/softmax.jl
new file mode 100644
index 0000000000..815f390754
--- /dev/null
+++ b/test/integration/triton/softmax.jl
@@ -0,0 +1,61 @@
+using PythonCall, Reactant, Test
+
+pyimport("sys").path.append(@__DIR__)
+
+softmax_kernel = pyimport("softmax").softmax_kernel
+
+const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA")
+
+function softmax_naive(x::AbstractMatrix{T}) where {T}
+    x_max = maximum(x; dims=1)
+    z = x .- x_max
+    num = exp.(z)
+    denom = sum(num; dims=1)
+    return num ./ denom
+end
+
+function softmax_triton(x::AbstractMatrix{T}) where {T}
+    x_transposed = permutedims(x, (2, 1)) # match python array layout
+    out = similar(x_transposed)
+    n_rows, n_cols = size(x_transposed)
+
+    BLOCK_SIZE = nextpow(2, n_cols)
+
+    function grid_fn(metadata)
+        occupancy = (
+            metadata.device_properties.regs_per_block ÷
+            (metadata.num_regs * metadata.device_properties.warp_size * metadata.num_warps)
+        )
+
+        num_programs = min(
+            metadata.device_properties.multi_processor_count * min(
+                occupancy,
+                metadata.device_properties.shared_mem_per_block ÷ metadata.metadata.shared,
+            ),
+            n_rows,
+        )
+        return num_programs
+    end
+
+    softmax_kernel(
+        out,
+        x_transposed,
+        Reactant.rowmajor_stride(x_transposed, 1),
+        Reactant.rowmajor_stride(out, 1),
+        n_rows,
+        n_cols,
+        BLOCK_SIZE,
+        num_stages=3;
+        grid=grid_fn,
+    )
+
+    return permutedims(out, (2, 1))
+end
+
+@testset "softmax" begin
+    if RunningOnCUDA
+        x_ra = Reactant.to_rarray(rand(Float32, 132, 2056))
+
+        @test @jit(softmax_triton(x_ra)) ≈ @jit(softmax_naive(x_ra))
+    end
+end
diff --git a/test/integration/triton/softmax.py b/test/integration/triton/softmax.py
new file mode 100644
index 0000000000..0a80c43275
--- /dev/null
+++ b/test/integration/triton/softmax.py
@@ -0,0 +1,38 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    num_stages: tl.constexpr,
+):
+    # starting row of the program
+    row_start = tl.program_id(0)
+    row_step = tl.num_programs(0)
+    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
+        # The stride represents how much we need to increase the pointer to advance 1 row
+        row_start_ptr = input_ptr + row_idx * input_row_stride
+        # The block size is the next power of two greater than n_cols, so we can fit each
+        # row in a single block
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
+        mask = col_offsets < n_cols
+        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
+        # Subtract maximum for numerical stability
+        row_minus_max = row - tl.max(row, axis=0)
+        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+        # Write back output to DRAM
+        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, softmax_output, mask=mask)
diff --git a/test/integration/triton/vector_add.jl b/test/integration/triton/vector_add.jl
new file mode 100644
index 0000000000..5a96e3b785
--- /dev/null
+++ b/test/integration/triton/vector_add.jl
@@ -0,0 +1,22 @@
+using PythonCall, Reactant, Test
+
+pyimport("sys").path.append(@__DIR__)
+
+add_kernel = pyimport("vector_add").add_kernel
+
+const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA")
+
+function vector_add_triton(x::AbstractVector{T}, y::AbstractVector{T}) where {T}
+    out = similar(x)
+    add_kernel(x, y, out, length(x), 1024; grid=(cld(length(x), 1024),))
+    return out
+end
+
+@testset "vector_add" begin
+    if RunningOnCUDA
+        x_ra = Reactant.to_rarray(rand(Float32, 2096))
+        y_ra = Reactant.to_rarray(rand(Float32, 2096))
+
+        @test @jit(vector_add_triton(x_ra, y_ra)) ≈ @jit(x_ra .+ y_ra)
+    end
+end
diff --git a/test/integration/triton/vector_add.py b/test/integration/triton/vector_add.py
new file mode 100644
index 0000000000..6b04d51a7d
--- /dev/null
+++ b/test/integration/triton/vector_add.py
@@ -0,0 +1,31 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def add_kernel(
+    x_ptr,  # *Pointer* to first input vector.
+    y_ptr,  # *Pointer* to second input vector.
+    output_ptr,  # *Pointer* to output vector.
+    n_elements,  # Size of the vector.
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+    # NOTE: `constexpr` so it can be used as a shape value.
+):
+    # There are multiple 'programs' processing different data. We identify which program
+    # we are here:
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+    # This program will process inputs that are offset from the initial data.
+    # For instance, if you had a vector of length 256 and block_size of 64, the programs
+    # would each access the elements [0:64, 64:128, 128:192, 192:256].
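+    # For example, the vector_add.jl test in this PR launches this kernel with
+    # n_elements = 2096 and BLOCK_SIZE = 1024, i.e. a grid of cld(2096, 1024) = 3 programs;
+    # the third program starts at offset 2048 and the mask below keeps only the 48
+    # in-bounds elements 2048 through 2095.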
+    # Note that offsets is a list of pointers:
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses.
+    mask = offsets < n_elements
+    # Load x and y from DRAM, masking out any extra elements in case the input is not a
+    # multiple of the block size.
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    # Write x + y back to DRAM.
+    tl.store(output_ptr + offsets, output, mask=mask)
diff --git a/test/runtests.jl b/test/runtests.jl
index f812deee5c..91389b3231 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -59,6 +59,24 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
             nranks = 2
             run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
         end
+        @testset "Triton" begin
+            @safetestset "vector_add" include("integration/triton/vector_add.jl")
+            @safetestset "softmax" include("integration/triton/softmax.jl")
+            # @safetestset "matmul" include("integration/triton/matmul.jl") # XXX
+            @safetestset "low_memory_dropout" include(
+                "integration/triton/low_memory_dropout.jl"
+            )
+            @safetestset "layer norm" include("integration/triton/layer_norm.jl")
+            # @safetestset "attention" include("integration/triton/attention.jl")
+            @safetestset "libdevice" include("integration/triton/libdevice.jl")
+            # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl")
+            # @safetestset "persistant matmul" include(
+            #     "integration/triton/persistant_matmul.jl"
+            # )
+            # @safetestset "block scaled matmul" include(
+            #     "integration/triton/block_scaled_matmul.jl"
+            # )
+        end
     end

 if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"