From 12d0d31999c188d4b7fdc6df86357973e209f48c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 14:30:15 -0500 Subject: [PATCH 01/28] feat: julia api to access device properties [skip ci] --- src/Compiler.jl | 41 +++++++++++++---------- src/xla/Device.jl | 76 ++++++++++++++++++++++++++++++++++++++++++ src/xla/IFRT/Device.jl | 8 +++++ src/xla/PJRT/Device.jl | 8 +++++ src/xla/Stats.jl | 2 +- src/xla/XLA.jl | 9 ----- 6 files changed, 117 insertions(+), 27 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index a6be948b1c..5823c9a07b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1395,7 +1395,8 @@ function __get_compile_options_and_kwargs(; end function compile_mlir(f, args; client=nothing, kwargs...) - backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + client = client !== nothing ? client : XLA.default_backend() + backend = XLA.platform_name(client) if backend == "CUDA" backend = "GPU" @@ -1414,6 +1415,7 @@ function compile_mlir(f, args; client=nothing, kwargs...) compile_options; backend, runtime=XLA.runtime(client), + client, kwargs..., ) @@ -1430,11 +1432,9 @@ end const PartitionKA = Ref{Bool}(true) -const cubinChip = Ref{String}("sm_60") -const cubinFormat = Ref{String}("bin") const cuindexBitWidth = Ref{Int}(32) +const cubinFormat = Ref{String}("bin") const cuOptLevel = Ref{Int}(2) -const cuWarpSize = Ref{Int}(32) # Wgatever the relevant highest version from our LLVM is within NVPTX.td # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684 @@ -1580,8 +1580,11 @@ function compile_mlir!( backend="gpu", runtime::Union{Val{:PJRT},Val{:IFRT}}, legalize_stablehlo_to_mhlo::Bool=false, + client=nothing, kwargs..., ) + client = client !== nothing ? client : XLA.default_backend() + # Explicitly don't use block! 
to avoid creating a closure, which creates # both compile-time and relocatability issues @@ -1655,25 +1658,27 @@ function compile_mlir!( else jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce" end - elseif DEBUG_KERNEL[] - curesulthandler = dlsym( - Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" - ) - @assert curesulthandler !== nothing - curesulthandler = Base.reinterpret(UInt, curesulthandler) + else kern = if is_raising "lower-kernel{backend=cpu},symbol-dce,canonicalize" else "lower-kernel,canonicalize" end - jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" - else - kern = if is_raising - "lower-kernel{backend=cpu},symbol-dce,canonicalize" + + device_properties = XLA.device_properties(XLA.default_device(client)) + cubinChip = "sm_$(device_properties.major)$(device_properties.minor)" + + if DEBUG_KERNEL[] + curesulthandler = dlsym( + Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" + ) + @assert curesulthandler !== nothing + curesulthandler = Base.reinterpret(UInt, curesulthandler) + extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler " else - "lower-kernel,canonicalize" + extra_lowerjit_options = "" end - jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" + jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" end recognize_comms = true @@ -3477,7 +3482,8 @@ function compile_xla( context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + client = client !== nothing ? client : XLA.default_backend() + backend = XLA.platform_name(client) if backend == "CUDA" backend = "GPU" @@ -3498,6 +3504,7 @@ function compile_xla( compile_options; backend, runtime=XLA.runtime(client), + client, kwargs..., ) diff --git a/src/xla/Device.jl b/src/xla/Device.jl index 19e9ef737f..5e28cc3ce3 100644 --- a/src/xla/Device.jl +++ b/src/xla/Device.jl @@ -11,6 +11,7 @@ function device_kind end function default_memory end function memories end function is_addressable end +function get_local_hardware_id end """ device_ordinal(device::Device) @@ -29,3 +30,78 @@ end function is_addressable(device::AbstractDevice) return device ∈ addressable_devices(client(device)) end + +# Keep in sync with API.cpp +struct DeviceProperties + total_global_mem::Csize_t + shared_mem_per_block::Csize_t + regs_per_block::Cint + warp_size::Cint + max_threads_per_block::Cint + max_threads_dim::NTuple{3,Cint} + max_grid_size::NTuple{3,Cint} + clock_rate::Cint + total_const_mem::Csize_t + major::Cint + minor::Cint + multi_processor_count::Cint + can_map_host_memory::Cint + compute_mode::Cint + l2_cache_size::Cint + max_threads_per_multiprocessor::Cint +end + +const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}() + +""" + device_properties(device::AbstractDevice) + +Get a struct containing device properties. 
Which exact fields are populated depends on the +underlying device implementation. +""" +function device_properties(device::AbstractDevice) + pname = platform_name(client(device)) + local_hardware_id = get_local_hardware_id(device) + + if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname)) + return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] + end + + jldevprops = Ref{DeviceProperties}() + if pname == "cuda" + GC.@preserve jldevprops begin + @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties( + jldevprops::Ptr{Cvoid}, local_hardware_id::Cint + )::Cvoid + end + else + @warn "`device_properties` not implemented for platform: $(pname)" maxlog = 1 + end + DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[] + return jldevprops[] +end + +function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties) + return print( + io, + """ + DeviceProperties + ---------------- + Total Global Mem: $(_format_bytes(props.total_global_mem)) + Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block)) + Regs Per Block: $(props.regs_per_block) + Warp Size: $(props.warp_size) + Max Threads Per Block: $(props.max_threads_per_block) + Max Threads Dim: $(props.max_threads_dim) + Max Grid Size: $(props.max_grid_size) + Clock Rate: $(props.clock_rate) + Total Const Mem: $(_format_bytes(props.total_const_mem)) + Version: $(VersionNumber(props.major, props.minor)) + Multi Processor Count: $(props.multi_processor_count) + Can Map Host Memory: $(props.can_map_host_memory) + Compute Mode: $(props.compute_mode) + L2 Cache Size: $(props.l2_cache_size) + Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor) + """, + ) +end diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl index 7d269e166c..dba2c98cd5 100644 --- a/src/xla/IFRT/Device.jl +++ b/src/xla/IFRT/Device.jl @@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device) return error("Not implemented for ifrt devices") end +function XLA.get_local_hardware_id(::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId( + device.device::Ptr{Cvoid} + )::Cint + end +end + function XLA.default_memory(device::Device) GC.@preserve device begin return Memory( diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl index 2a29c6279b..4a4dd178e7 100644 --- a/src/xla/PJRT/Device.jl +++ b/src/xla/PJRT/Device.jl @@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device) end end +function XLA.get_local_hardware_id(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId( + device.device::Ptr{Cvoid} + )::Cint + end +end + function XLA.is_addressable(device::Device) GC.@preserve device begin return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable( diff --git a/src/xla/Stats.jl b/src/xla/Stats.jl index bc66cc348a..59f62609c2 100644 --- a/src/xla/Stats.jl +++ b/src/xla/Stats.jl @@ -13,7 +13,7 @@ struct JLAllocatorStats peak_pool_bytes::Int64 end -_format_bytes(x) = Base.format_bytes(x) +_format_bytes(x) = x < 0 ? nothing : Base.format_bytes(x) _format_bytes(x::Nothing) = x """ diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 1a7ffc17f2..f14139b890 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT) ) state.clients["cuda"] = gpu state.default_client = gpu - - # set values for cuda. This is being done here since we need cuda - # to be initialized before we can use it. initializing the devices - # implicitly initializes cuda.
- cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32 - cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32 - Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)" - - Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32 catch e println(stdout, e) end From 77da391b78c8a38a94316697198cb611d0204c5c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 16:53:08 -0400 Subject: [PATCH 02/28] fix: apply suggestion from @avik-pal [skip ci] --- src/xla/IFRT/Device.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl index dba2c98cd5..672900454a 100644 --- a/src/xla/IFRT/Device.jl +++ b/src/xla/IFRT/Device.jl @@ -31,7 +31,7 @@ function XLA.get_local_device_id(::Device) return error("Not implemented for ifrt devices") end -function XLA.get_local_hardware_id(::Device) +function XLA.get_local_hardware_id(device::Device) GC.@preserve device begin return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId( device.device::Ptr{Cvoid} )::Cint end end From c80c3deb905389d58982711d7936ac73c08f0f72 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 07:58:36 -0500 Subject: [PATCH 03/28] chore: bump reactant_jll --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fbde73622a..f08ffd3b90 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.251" +Reactant_jll = "0.0.252" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" From fd9f3c956dfc6b08d35e44c5620cdf05138a88b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 19:23:19 -0500 Subject: [PATCH 04/28] fix: remove deleted fields [skip ci] --- Project.toml | 2 +- src/xla/Device.jl | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index f08ffd3b90..fbde73622a 100644 --- a/Project.toml +++ b/Project.toml @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.252" +Reactant_jll = "0.0.251" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/src/xla/Device.jl b/src/xla/Device.jl index 5e28cc3ce3..fd76bb6e3e 100644 --- a/src/xla/Device.jl +++ b/src/xla/Device.jl @@ -40,13 +40,11 @@ struct DeviceProperties max_threads_per_block::Cint max_threads_dim::NTuple{3,Cint} max_grid_size::NTuple{3,Cint} - clock_rate::Cint total_const_mem::Csize_t major::Cint minor::Cint multi_processor_count::Cint can_map_host_memory::Cint - compute_mode::Cint l2_cache_size::Cint max_threads_per_multiprocessor::Cint end @@ -94,12 +92,10 @@ function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties) Max Threads Per Block: $(props.max_threads_per_block) Max Threads Dim: $(props.max_threads_dim) Max Grid Size: $(props.max_grid_size) - Clock Rate: $(props.clock_rate) Total Const Mem: $(_format_bytes(props.total_const_mem)) Version: $(VersionNumber(props.major, props.minor)) Multi Processor Count: $(props.multi_processor_count) Can Map Host Memory: $(props.can_map_host_memory) - Compute Mode: $(props.compute_mode) L2 Cache Size: $(props.l2_cache_size) Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor) """, From 33e9d68a17d83610e1eb6c51f97c608e45d21fcf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 22:16:04 -0500 Subject: [PATCH 05/28] feat:
initial triton setup [skip ci] --- CondaPkg.toml | 1 + .../ReactantPythonCallExt.jl | 14 +++++++++++++- ext/ReactantPythonCallExt/overlays.jl | 2 +- ext/ReactantPythonCallExt/pycall.jl | 16 +++++++++++++++- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/CondaPkg.toml b/CondaPkg.toml index b1db4f8e75..40cc769513 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" +triton = "" # TODO: version bound diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 1f10630808..3b4a85e951 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -1,6 +1,6 @@ module ReactantPythonCallExt -using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist +using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall @@ -9,6 +9,10 @@ const jnpptr = Ref{Py}() const JAX_TRACING_SUPPORTED = Ref{Bool}(false) +const tritonptr = Ref{Py}() + +const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false) + const tfptr = Ref{Py}() const tf2xlaptr = Ref{Py}() const npptr = Ref{Py}() @@ -43,6 +47,14 @@ function __init__() be supported." exception = (err, catch_backtrace()) end + try + tritonptr[] = pyimport("triton") + TRITON_COMPILE_SUPPORTED[] = true + catch err + @warn "Failed to import triton. Compiling jax functions with triton won't be \ + supported." exception = (err, catch_backtrace()) + end + try tfptr[] = pyimport("tensorflow") tfptr[].config.set_visible_devices(pylist(); device_type="GPU") diff --git a/ext/ReactantPythonCallExt/overlays.jl b/ext/ReactantPythonCallExt/overlays.jl index 20ffa7384f..20a9210023 100644 --- a/ext/ReactantPythonCallExt/overlays.jl +++ b/ext/ReactantPythonCallExt/overlays.jl @@ -1,6 +1,6 @@ @reactant_overlay function PythonCall.pycall(f::Py, args...) if Reactant.looped_any(Reactant.use_overlayed_version, args) - return pycall_with_jax_tracing(f, args...) + return overlayed_pycall(f, args...) else return Base.inferencebarrier(PythonCall.pycall)(f, args...) end diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 23674d9155..8f81b50049 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -7,7 +7,17 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe ) end -function pycall_with_jax_tracing(f::Py, args...) +function overlayed_pycall(f::Py, args...) + @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[] + # TODO: check for Autotuner and Heutistics as well + if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction) + return overlayed_pycall_with_triton(f, args...) + else + return overlayed_pycall_with_jax_tracing(f, args...) + end +end + +function overlayed_pycall_with_jax_tracing(f::Py, args...) JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.") seen_args = Reactant.OrderedIdDict() @@ -35,3 +45,7 @@ function pycall_with_jax_tracing(f::Py, args...) res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end + +function overlayed_pycall_with_triton(f::Py, args...) 
+ error("TODO: implement triton") +end From 75f000ea0421921332ba4a14f6f7add8c508de4a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 23:30:22 -0500 Subject: [PATCH 06/28] feat: auto-trace triton code --- CondaPkg.toml | 2 +- .../ReactantPythonCallExt.jl | 26 ++++++- ext/ReactantPythonCallExt/overlays.jl | 6 +- ext/ReactantPythonCallExt/pycall.jl | 72 +++++++++++++++++-- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/CondaPkg.toml b/CondaPkg.toml index 40cc769513..00aa12cb4a 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,4 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" -triton = "" # TODO: version bound +triton = ">= 3.4" diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 3b4a85e951..af3852ce2e 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -1,8 +1,10 @@ module ReactantPythonCallExt -using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance +using PythonCall: + PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall +using Reactant_jll: Reactant_jll const jaxptr = Ref{Py}() const jnpptr = Ref{Py}() @@ -37,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict( ComplexF64 => :complex64, ) +const MLIR_TYPE_STRING = Dict( + Float64 => "fp64", + Float32 => "fp32", + Float16 => "fp16", + Int64 => "i64", + Int32 => "i32", + Int16 => "i16", + Int8 => "i8", + UInt64 => "ui64", + UInt32 => "ui32", + UInt16 => "ui16", + UInt8 => "ui8", + Bool => "i1", + Reactant.F8E4M3FN => "fp8e4nv", + Reactant.F8E5M2FNUZ => "fp8e5b16", + Reactant.F8E4M3FNUZ => "fp8e4b8", + Reactant.F8E5M2 => "fp8e5", +) +if isdefined(Core, :BFloat16) + MLIR_TYPE_STRING[Core.BFloat16] = "bf16" +end + function __init__() try jaxptr[] = pyimport("jax") diff --git a/ext/ReactantPythonCallExt/overlays.jl b/ext/ReactantPythonCallExt/overlays.jl index 20a9210023..ca5bcfcea5 100644 --- a/ext/ReactantPythonCallExt/overlays.jl +++ b/ext/ReactantPythonCallExt/overlays.jl @@ -1,7 +1,7 @@ -@reactant_overlay function PythonCall.pycall(f::Py, args...) +@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...) if Reactant.looped_any(Reactant.use_overlayed_version, args) - return overlayed_pycall(f, args...) + return overlayed_pycall(f, args...; kwargs...) else - return Base.inferencebarrier(PythonCall.pycall)(f, args...) + return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...) end end diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 8f81b50049..7786c1b73a 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -7,12 +7,13 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe ) end -function overlayed_pycall(f::Py, args...) +function overlayed_pycall(f::Py, args...; kwargs...) @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[] # TODO: check for Autotuner and Heutistics as well if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction) - return overlayed_pycall_with_triton(f, args...) + return overlayed_pycall_with_triton(f, args...; kwargs...) else + @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions." return overlayed_pycall_with_jax_tracing(f, args...) 
end end @@ -46,6 +47,69 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end -function overlayed_pycall_with_triton(f::Py, args...) - error("TODO: implement triton") +# TODO: support using metaparams here +normalize_grid(grid::Integer) = normalize_grid((grid,)) +function normalize_grid(grid::Dims{N}) where {N} + @assert N <= 3 + @assert all(grid .> 0) + return (grid..., ntuple(_ -> 1, 3 - N)...) +end + +signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing +signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing +signature_string(x::T) where {T<:Number} = string(x), x +signature_string(x) = error("Unsupported argument type: $(typeof(x))") + +function overlayed_pycall_with_triton( + kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing +) + triton = tritonptr[] + + grid = normalize_grid(grid) + + mapped = map(signature_string, args) + signature = first.(mapped) + # TODO: are hints actually correctly set? + hints = + hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints) + constants = Dict( + kernel.arg_names[i - 1] => constant for + (i, constant) in enumerate(last.(mapped)) if constant !== nothing + ) + for (k, v) in hints + v == 1 && (constants[kernel.arg_names[k - 1]] = v) + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature)) + for k in keys(constants) + sigmap[k] = "constexpr" + end + + for h in values(hints) + @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h" + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + src = triton.compiler.ASTSource(; + fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs + ) + + # TODO: check that we are using CUDA. 
Get compute_capability from the target + target = triton.backends.compiler.GPUTarget("cuda", 80, 32) + backend = triton.compiler.make_backend(target) + options = backend.parse_options( + pydict( + "num_warps" => num_warps, + "num_stages" => num_stages, + "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)), + ), + ) + + ccinfo = triton.compile(src; target=target, options=options.__dict__) + + println(pyconvert(String, ccinfo.asm["source"])) + println(pyconvert(String, ccinfo.asm["ttir"])) + + return error("TODO: implement triton") end From 66d3580d0c50b59894dd3f83c71bf9a6df2c4b7d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Sep 2025 00:04:41 -0500 Subject: [PATCH 07/28] feat: copy tt.func into main module [skip ci] --- ext/ReactantPythonCallExt/pycall.jl | 14 ++++- src/Ops.jl | 98 +++++++++++++++++------------ 2 files changed, 70 insertions(+), 42 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 7786c1b73a..c1c5662a67 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -108,8 +108,18 @@ function overlayed_pycall_with_triton( ccinfo = triton.compile(src; target=target, options=options.__dict__) - println(pyconvert(String, ccinfo.asm["source"])) - println(pyconvert(String, ccinfo.asm["ttir"])) + @show ccinfo.metadata + @show ccinfo.asm.keys() + # shared = ccinfo.metadata["shared"] + kernel_name = pyconvert(String, ccinfo.metadata.name) + # cluster_dims = ccinfo.metadata["cluster_dims"] + + # println(pyconvert(String, ccinfo.asm["source"])) + # println(pyconvert(String, ccinfo.asm["ttir"])) + + res = @opcall triton_call( + pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name + ) return error("TODO: implement triton") end diff --git a/src/Ops.jl b/src/Ops.jl index 252d87dc84..2c79753652 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1743,51 +1743,22 @@ end end # Generate a unique name given a module hash and a function name. -function _hlo_call_name(orig_name, module_suffix) - return orig_name * "_hlo_call_" * module_suffix -end +_new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix -""" - hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray} - -Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main") -with the provided arguments and return a tuple for each result of the call. 
- -```julia-repl -julia> Reactant.@jit( - hlo_call( - \"\"\" - module { - func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { - %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> - return %0 : tensor<3xf32> - } - } - \"\"\", - Reactant.to_rarray(Float32[1, 2, 3]), - Reactant.to_rarray(Float32[1, 2, 3]), - ) - ) -(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) -``` -""" -@noinline function hlo_call( - code, - args...; - func_name="main", - location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), +function _extract_function( + code::String; func_name::String="main", func_op_kind::String="func.func" ) module_suffix = string(hash(code); base=16) - name_to_call = _hlo_call_name(func_name, module_suffix) + name_to_call = _new_function_name(func_name, module_suffix) current_module = MLIR.IR.mmodule() top_level_block = MLIR.IR.body(current_module) symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - fn = MLIR.IR.lookup( MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call ) + if isnothing(fn) new_mod = parse(MLIR.IR.Module, code) new_mod_op = MLIR.IR.Operation(new_mod) @@ -1795,16 +1766,15 @@ julia> Reactant.@jit( operations = collect(MLIR.IR.OperationIterator(body)) for op in operations - if MLIR.IR.name(op) == "func.func" + if MLIR.IR.name(op) == func_op_kind fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) if fn_name == func_name fn = op end - new_name = _hlo_call_name(fn_name, module_suffix) res = MLIR.IR.LogicalResult( MLIR.API.mlirSymbolTableReplaceAllSymbolUses( - fn_name, new_name, new_mod_op + fn_name, name_to_call, new_mod_op ), ) @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" @@ -1817,7 +1787,7 @@ julia> Reactant.@jit( ) # Change function name - MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name)) + MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) end end @@ -1831,11 +1801,59 @@ julia> Reactant.@jit( error("hlo_call: could not find function $func_name in the provided module") end + return name_to_call +end + +function triton_call( + mlir_code::String, + args::Union{TracedRArray,TracedRNumber,Number}...; + func_name::String="main", + location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), +) + name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") + + @show name_to_call + display(MLIR.IR.mmodule()) + + error("TODO: implement triton_call") +end + +""" + hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray} + +Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main") +with the provided arguments and return a tuple for each result of the call. 
+ +```julia-repl +julia> Reactant.@jit( + hlo_call( + \"\"\" + module { + func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> + return %0 : tensor<3xf32> + } + } + \"\"\", + Reactant.to_rarray(Float32[1, 2, 3]), + Reactant.to_rarray(Float32[1, 2, 3]), + ) + ) +(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) +``` +""" +@noinline function hlo_call( + code, + args::Union{TracedRArray,TracedRNumber}...; + func_name="main", + location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), +) + name_to_call = _extract_function(code; func_name, func_op_kind="func.func") + ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) - @assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers" - @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name" + @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name. Expected $(MLIR.IR.ninputs(ftype)), got $(length(args))" for (i, arg) in enumerate(args) expected_type = MLIR.IR.input(ftype, i) From 9447d0b58edcfb573ce847d228db3dd695cdae74 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 08:56:21 -0500 Subject: [PATCH 08/28] feat: tracing fully functional --- ext/ReactantPythonCallExt/pycall.jl | 23 +++++++++++------------ src/Ops.jl | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index c1c5662a67..4c9a8cba82 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -106,20 +106,19 @@ function overlayed_pycall_with_triton( ), ) + # Currently we are doing a double compilation here. can we do better? 
+ # we are compiling here + lowering again inside enzymejax ccinfo = triton.compile(src; target=target, options=options.__dict__) - @show ccinfo.metadata - @show ccinfo.asm.keys() - # shared = ccinfo.metadata["shared"] - kernel_name = pyconvert(String, ccinfo.metadata.name) - # cluster_dims = ccinfo.metadata["cluster_dims"] - - # println(pyconvert(String, ccinfo.asm["source"])) - # println(pyconvert(String, ccinfo.asm["ttir"])) - - res = @opcall triton_call( - pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name + @opcall triton_call( + pyconvert(String, ccinfo.asm["ttir"]), + filter(x -> x isa Reactant.TracedType, args)...; + func_name=pyconvert(String, ccinfo.metadata.name), + grid_x=@opcall(constant(grid[1])), + grid_y=@opcall(constant(grid[2])), + grid_z=@opcall(constant(grid[3])), + shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))), ) - return error("TODO: implement triton") + return nothing end diff --git a/src/Ops.jl b/src/Ops.jl index 2c79753652..db730361e0 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1808,14 +1808,27 @@ function triton_call( mlir_code::String, args::Union{TracedRArray,TracedRNumber,Number}...; func_name::String="main", + grid_x::TracedRNumber{<:Integer}, + grid_y::TracedRNumber{<:Integer}, + grid_z::TracedRNumber{<:Integer}, + shmem::TracedRNumber{<:Integer}, location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), + # TODO: other kwargs ) name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") - @show name_to_call - display(MLIR.IR.mmodule()) + enzymexla.triton_call( + grid_x.mlir_data, + grid_y.mlir_data, + grid_z.mlir_data, + shmem.mlir_data, + [Reactant.TracedUtils.get_mlir_data(a) for a in args]; + fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + result_0=MLIR.IR.Type[], + location, + ) - error("TODO: implement triton_call") + return nothing end """ From 05003974fd26915493fee60ee28e122efeaa31a3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 10:11:26 -0500 Subject: [PATCH 09/28] fix: hlo_call --- src/Ops.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index db730361e0..b82b5fb171 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1801,7 +1801,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return name_to_call + return fn, name_to_call end function triton_call( @@ -1815,7 +1815,7 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") + _, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") enzymexla.triton_call( grid_x.mlir_data, @@ -1861,7 +1861,7 @@ julia> Reactant.@jit( func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - name_to_call = _extract_function(code; func_name, func_op_kind="func.func") + fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func") ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) From 532a302f8f2616f1fd12c3b27b931a163a2af609 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 23:19:56 -0500 Subject: [PATCH 10/28] feat: more triton passes + keep triton func in a separate module --- deps/ReactantExtra/BUILD | 3 ++ ext/ReactantPythonCallExt/pycall.jl | 10 +++-- src/Compiler.jl | 60 ++++++++++++++++++++++++++++- src/Ops.jl | 14 ++++++- src/mlir/IR/Module.jl | 3 +- 5 files changed, 83 insertions(+), 
7 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 4f3ce35cfe..920208e194 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -979,6 +979,9 @@ cc_library( "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", "-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion", "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads", "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 4c9a8cba82..40026af81f 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -60,6 +60,7 @@ signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothi signature_string(x::T) where {T<:Number} = string(x), x signature_string(x) = error("Unsupported argument type: $(typeof(x))") +# TODO: better name for hints? function overlayed_pycall_with_triton( kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing ) @@ -95,8 +96,11 @@ function overlayed_pycall_with_triton( fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs ) - # TODO: check that we are using CUDA. Get compute_capability from the target - target = triton.backends.compiler.GPUTarget("cuda", 80, 32) + target = triton.backends.compiler.GPUTarget( + "cuda", + parse(Int, Reactant.Compiler.cubinChip[][4:end]), + Reactant.Compiler.cuWarpSize[], + ) backend = triton.compiler.make_backend(target) options = backend.parse_options( pydict( @@ -111,7 +115,7 @@ function overlayed_pycall_with_triton( ccinfo = triton.compile(src; target=target, options=options.__dict__) @opcall triton_call( - pyconvert(String, ccinfo.asm["ttir"]), + pyconvert(String, ccinfo.asm["source"]), filter(x -> x isa Reactant.TracedType, args)...; func_name=pyconvert(String, ccinfo.metadata.name), grid_x=@opcall(constant(grid[1])), diff --git a/src/Compiler.jl b/src/Compiler.jl index 5823c9a07b..78b495660f 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1300,9 +1300,66 @@ function optimization_passes( push!(passes, "remove-duplicate-func-def") end push!(passes, func_passes) + if backend == "cuda" + push!(passes, triton_optimization_passes()) + end return join(passes, ',') end +# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc +function triton_optimization_passes() + # TODO: check that all triton passes are included here + return join( + [ + # convert passes + "convert-scf-to-cf", + "convert-cf-to-llvm", + "convert-index-to-llvm", + "convert-arith-to-llvm", + "convert-nvvm-to-llvm", + # common passes + "canonicalize", + # # ttir passes + # "triton-combine", + # "triton-reorder-broadcast", + # "triton-rewrite-tensor-pointer", + # "triton-rewrite-tensor-descriptor-to-pointer", + # "triton-loop-unroll", + # "triton-licm", + # "triton-loop-aware-cse", + # # TODO: should num-warps and num-ctas be set for each kernel? 
+ # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", + # # ttgir passes + # "tritongpu-coalesce", + # "tritongpu-optimize-thread-locality", + # "tritongpu-hoist-tmem-alloc", + # "tritongpu-assign-latencies", + # "tritongpu-pipeline", + # "tritongpu-schedule-loops", + # "tritongpu-automatic-warp-specialization", + # "tritongpu-prefetch", + # "tritongpu-accelerate-matmul", + # "tritongpu-reorder-instructions", + # "tritongpu-F32DotTC", + # "tritongpu-optimize-dot-operands", + # "tritongpu-remove-layout-conversions", + # "tritongpu-reduce-data-duplication", + # "tritongpu-hoist-tmem-alloc", + # "tritongpu-fuse-nested-loops", + # "tritongpu-rewrite-partition-dependencies", + # "tritongpu-partition-loops", + # "tritongpu-combine-tensor-select-and-if", + # # ttgir to llvm passes + # "tritongpu-allocate-warp-groups", + # "allocate-shared-memory", + # "tritongpu-global-scratch-memory-allocation", + # "tritongpu-optimize-accumulator-init", + # "tritongpu-coalesce-async-copy", + ], + ",", + ) +end + # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" @@ -2261,7 +2318,8 @@ function compile_mlir!( end end - run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") + # XXX: re-enable this pass + # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") func_op = MLIR.API.mlirSymbolTableLookup( MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname diff --git a/src/Ops.jl b/src/Ops.jl index b82b5fb171..d2ad70a9c7 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1746,12 +1746,20 @@ end _new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix function _extract_function( - code::String; func_name::String="main", func_op_kind::String="func.func" + code::String; + func_name::String="main", + func_op_kind::String="func.func", + nested_module::Bool=false, ) module_suffix = string(hash(code); base=16) name_to_call = _new_function_name(func_name, module_suffix) current_module = MLIR.IR.mmodule() + if nested_module + new_module = MLIR.IR.Module() + push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true)) + current_module = new_module + end top_level_block = MLIR.IR.body(current_module) symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) @@ -1815,7 +1823,9 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - _, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") + _, name_to_call = _extract_function( + mlir_code; func_name, func_op_kind="tt.func", nested_module=true + ) enzymexla.triton_call( grid_x.mlir_data, diff --git a/src/mlir/IR/Module.jl b/src/mlir/IR/Module.jl index 12794b30ba..c7f938d5b8 100644 --- a/src/mlir/IR/Module.jl +++ b/src/mlir/IR/Module.jl @@ -52,7 +52,8 @@ body(module_) = Block(API.mlirModuleGetBody(module_), false) Views the module as a generic operation. 
""" -Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false) +Operation(module_::Module, owned::Bool=false) = + Operation(API.mlirModuleGetOperation(module_), owned) function Base.show(io::IO, module_::Module) return show(io, Operation(module_)) From e1d9fc0709d0d5885c94de9e71e979637a94cde4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Sep 2025 00:14:35 -0500 Subject: [PATCH 11/28] feat: put the tt func in a separate module and use symbol ref --- src/Compiler.jl | 75 ++++++++++++++++++++--------------------- src/Ops.jl | 90 +++++++++++++++++++++++++++---------------------- 2 files changed, 87 insertions(+), 78 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 78b495660f..5059acd94c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1319,42 +1319,42 @@ function triton_optimization_passes() "convert-nvvm-to-llvm", # common passes "canonicalize", - # # ttir passes - # "triton-combine", - # "triton-reorder-broadcast", - # "triton-rewrite-tensor-pointer", - # "triton-rewrite-tensor-descriptor-to-pointer", - # "triton-loop-unroll", - # "triton-licm", - # "triton-loop-aware-cse", - # # TODO: should num-warps and num-ctas be set for each kernel? - # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", - # # ttgir passes - # "tritongpu-coalesce", - # "tritongpu-optimize-thread-locality", - # "tritongpu-hoist-tmem-alloc", - # "tritongpu-assign-latencies", - # "tritongpu-pipeline", - # "tritongpu-schedule-loops", - # "tritongpu-automatic-warp-specialization", - # "tritongpu-prefetch", - # "tritongpu-accelerate-matmul", - # "tritongpu-reorder-instructions", - # "tritongpu-F32DotTC", - # "tritongpu-optimize-dot-operands", - # "tritongpu-remove-layout-conversions", - # "tritongpu-reduce-data-duplication", - # "tritongpu-hoist-tmem-alloc", - # "tritongpu-fuse-nested-loops", - # "tritongpu-rewrite-partition-dependencies", - # "tritongpu-partition-loops", - # "tritongpu-combine-tensor-select-and-if", - # # ttgir to llvm passes - # "tritongpu-allocate-warp-groups", - # "allocate-shared-memory", - # "tritongpu-global-scratch-memory-allocation", - # "tritongpu-optimize-accumulator-init", - # "tritongpu-coalesce-async-copy", + # ttir passes + "triton-combine", + "triton-reorder-broadcast", + "triton-rewrite-tensor-pointer", + "triton-rewrite-tensor-descriptor-to-pointer", + "triton-loop-unroll", + "triton-licm", + "triton-loop-aware-cse", + # TODO: should num-warps and num-ctas be set for each kernel? 
+ "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", + # ttgir passes + "tritongpu-coalesce", + "tritongpu-optimize-thread-locality", + "tritongpu-hoist-tmem-alloc", + "tritongpu-assign-latencies", + "tritongpu-pipeline", + "tritongpu-schedule-loops", + "tritongpu-automatic-warp-specialization", + "tritongpu-prefetch", + "tritongpu-accelerate-matmul", + "tritongpu-reorder-instructions", + "tritongpu-F32DotTC", + "tritongpu-optimize-dot-operands", + "tritongpu-remove-layout-conversions", + "tritongpu-reduce-data-duplication", + "tritongpu-hoist-tmem-alloc", + "tritongpu-fuse-nested-loops", + "tritongpu-rewrite-partition-dependencies", + "tritongpu-partition-loops", + "tritongpu-combine-tensor-select-and-if", + # ttgir to llvm passes + "tritongpu-allocate-warp-groups", + "allocate-shared-memory", + "tritongpu-global-scratch-memory-allocation", + "tritongpu-optimize-accumulator-init", + "tritongpu-coalesce-async-copy", ], ",", ) @@ -2318,8 +2318,7 @@ function compile_mlir!( end end - # XXX: re-enable this pass - # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") + run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") func_op = MLIR.API.mlirSymbolTableLookup( MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname diff --git a/src/Ops.jl b/src/Ops.jl index d2ad70a9c7..1e9f404369 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1750,22 +1750,32 @@ function _extract_function( func_name::String="main", func_op_kind::String="func.func", nested_module::Bool=false, + location::MLIR.IR.Location=MLIR.IR.Location(), ) module_suffix = string(hash(code); base=16) - name_to_call = _new_function_name(func_name, module_suffix) + name_to_call = func_name * "_call_" * module_suffix + mod_name = func_name * "_module_" * module_suffix + symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - current_module = MLIR.IR.mmodule() if nested_module - new_module = MLIR.IR.Module() - push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true)) - current_module = new_module - end - top_level_block = MLIR.IR.body(current_module) + region = MLIR.IR.Region() + push!(region, MLIR.IR.Block()) + moduleop = MLIR.Dialects.builtin.module_(; + location, bodyRegion=region, sym_name=mod_name + ) + MLIR.IR.rmfromparent!(moduleop) + push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module - symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - fn = MLIR.IR.lookup( - MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call - ) + top_level_block = MLIR.IR.Block( + MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false + ) + fn = nothing + else + current_module = MLIR.IR.mmodule() + moduleop = MLIR.IR.Operation(current_module) + top_level_block = MLIR.IR.body(current_module) + fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call) + end if isnothing(fn) new_mod = parse(MLIR.IR.Module, code) @@ -1773,31 +1783,27 @@ function _extract_function( body = MLIR.IR.body(new_mod) operations = collect(MLIR.IR.OperationIterator(body)) - for op in operations - if MLIR.IR.name(op) == func_op_kind - fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) - if fn_name == func_name - fn = op - end + idx = Base.findfirst(op -> MLIR.IR.name(op) == func_op_kind, operations) + @assert idx !== nothing + op = operations[idx] - res = MLIR.IR.LogicalResult( - MLIR.API.mlirSymbolTableReplaceAllSymbolUses( - fn_name, 
name_to_call, new_mod_op - ), - ) - @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" - - # Set function private - MLIR.IR.attr!( - op, - MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), - MLIR.IR.Attribute("private"), - ) - - # Change function name - MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) - end - end + fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) + fn_name == func_name && (fn = op) + + res = MLIR.IR.LogicalResult( + MLIR.API.mlirSymbolTableReplaceAllSymbolUses(fn_name, name_to_call, new_mod_op) + ) + @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" + + # Set function private + MLIR.IR.attr!( + op, + MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), + MLIR.IR.Attribute("private"), + ) + + # Change function name + MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) for op in operations MLIR.IR.rmfromparent!(op) @@ -1809,7 +1815,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return fn, name_to_call + return fn, name_to_call, mod_name end function triton_call( @@ -1823,8 +1829,8 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - _, name_to_call = _extract_function( - mlir_code; func_name, func_op_kind="tt.func", nested_module=true + _, name_to_call, mod_name = _extract_function( + mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location ) enzymexla.triton_call( @@ -1833,7 +1839,9 @@ function triton_call( grid_z.mlir_data, shmem.mlir_data, [Reactant.TracedUtils.get_mlir_data(a) for a in args]; - fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + fn=MLIR.IR.SymbolRefAttribute( + mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)] + ), result_0=MLIR.IR.Type[], location, ) @@ -1871,7 +1879,9 @@ julia> Reactant.@jit( func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func") + fn, name_to_call, _ = _extract_function( + code; func_name, func_op_kind="func.func", location + ) ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) From fa02d6ec2b7299f2ecf1ca53ae12d6f6bcd6dfbf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Sep 2025 13:07:04 -0500 Subject: [PATCH 12/28] feat: new triton_ext dialect --- deps/ReactantExtra/BUILD | 18 +++++++ deps/ReactantExtra/make-bindings.jl | 1 + src/mlir/Dialects/TritonExt.jl | 82 +++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+) create mode 100644 src/mlir/Dialects/TritonExt.jl diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 920208e194..5d8ba5eda4 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -1439,6 +1439,24 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "TritonExtJLIncGen", + tbl_outs = [ + ( + [ + "--generator=jl-op-defs", + "--disable-module-wrap=0", + ], + "TritonExt.jl", + ), + ], + tblgen = "//:mlir-jl-tblgen", + td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td", + deps = [ + "@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles", + ], +) + gentbl_cc_library( name = "TPUJLIncGen", tbl_outs = [ diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f84309fef1..ebdb7cd9b0 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -42,6 +42,7 @@ for file in [ "MPI.jl", "MemRef.jl", 
"SparseTensor.jl", + "TritonExt.jl" ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/src/mlir/Dialects/TritonExt.jl b/src/mlir/Dialects/TritonExt.jl new file mode 100644 index 0000000000..f59822b909 --- /dev/null +++ b/src/mlir/Dialects/TritonExt.jl @@ -0,0 +1,82 @@ +module triton_ext +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function call( + gridx::Value, + gridy::Value, + gridz::Value, + shmem::Value, + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[gridx, gridy, gridz, shmem, inputs...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "triton_ext.call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function module_(; sym_name, bodyRegion::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[bodyRegion,] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name),] + + return create_operation( + "triton_ext.module", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # triton_ext From a3b8cb65b1b264193b9a0d9bcc36c75039850d74 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Sep 2025 15:02:27 -0500 Subject: [PATCH 13/28] feat: triton tracing works now finally --- docs/src/.vitepress/config.mts | 2 + docs/src/api/dialects/tritonext.md | 11 ++++ src/Compiler.jl | 86 ++++++++++++++++++++---------- src/Ops.jl | 60 +++++++++++++-------- 4 files changed, 110 insertions(+), 49 deletions(-) create mode 100644 docs/src/api/dialects/tritonext.md diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index dacce466fb..2853e7abd5 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -131,6 +131,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], @@ -221,6 +222,7 @@ export default defineConfig({ { text: "SparseTensor", link: 
"/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], diff --git a/docs/src/api/dialects/tritonext.md b/docs/src/api/dialects/tritonext.md new file mode 100644 index 0000000000..a727f0dfbb --- /dev/null +++ b/docs/src/api/dialects/tritonext.md @@ -0,0 +1,11 @@ +```@meta +CollapsedDocStrings = true +``` + +# TritonExt Dialect + +Provides extensions to the Triton dialect. + +```@autodocs +Modules = [Reactant.MLIR.Dialects.triton_ext] +``` diff --git a/src/Compiler.jl b/src/Compiler.jl index 5059acd94c..bfc3382f3c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1307,57 +1307,89 @@ function optimization_passes( end # https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc +# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes function triton_optimization_passes() - # TODO: check that all triton passes are included here - return join( + all_passes = join( [ - # convert passes - "convert-scf-to-cf", - "convert-cf-to-llvm", - "convert-index-to-llvm", - "convert-arith-to-llvm", - "convert-nvvm-to-llvm", - # common passes "canonicalize", - # ttir passes + "triton-rewrite-tensor-pointer", + "canonicalize", "triton-combine", "triton-reorder-broadcast", - "triton-rewrite-tensor-pointer", - "triton-rewrite-tensor-descriptor-to-pointer", + "cse", + "symbol-dce", "triton-loop-unroll", - "triton-licm", - "triton-loop-aware-cse", - # TODO: should num-warps and num-ctas be set for each kernel? "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", - # ttgir passes "tritongpu-coalesce", + "tritongpu-F32DotTC", + "triton-nvidia-gpu-plan-cta", + "tritongpu-remove-layout-conversions", "tritongpu-optimize-thread-locality", + "tritongpu-accelerate-matmul", + "tritongpu-remove-layout-conversions", + "tritongpu-optimize-dot-operands", + "canonicalize", + "triton-nvidia-optimize-descriptor-encoding", + "triton-loop-aware-cse", + "tritongpu-fuse-nested-loops", + "canonicalize", + "triton-licm", + "tritongpu-optimize-accumulator-init", "tritongpu-hoist-tmem-alloc", + "tritongpu-promote-lhs-to-tmem", "tritongpu-assign-latencies", - "tritongpu-pipeline", "tritongpu-schedule-loops", "tritongpu-automatic-warp-specialization", + "tritongpu-partition-scheduling", + "tritongpu-load-mma-specialization", + "tritongpu-rewrite-partition-dependencies", + "sccp", + "cse", + "tritongpu-partition-loops", + "tritongpu-optimize-partition-warps", + "tritongpu-schedule-loops", + "tritongpu-pipeline", + "tritongpu-combine-tensor-select-and-if", + "triton-nvidia-gpu-remove-tmem-tokens", + "canonicalize", + "triton-loop-aware-cse", "tritongpu-prefetch", - "tritongpu-accelerate-matmul", - "tritongpu-reorder-instructions", - "tritongpu-F32DotTC", "tritongpu-optimize-dot-operands", + "canonicalize", + "tritongpu-coalesce-async-copy", + "triton-nvidia-optimize-tmem-layouts", "tritongpu-remove-layout-conversions", + "triton-nvidia-interleave-tmem", "tritongpu-reduce-data-duplication", - "tritongpu-hoist-tmem-alloc", - "tritongpu-fuse-nested-loops", - "tritongpu-rewrite-partition-dependencies", - "tritongpu-partition-loops", + "tritongpu-reorder-instructions", + "triton-loop-aware-cse", + "symbol-dce", + "triton-nvidia-tma-lowering", + 
"triton-nvidia-gpu-fence-insertion", + "sccp", + "canonicalize", + "triton-nvidia-mma-lowering", "tritongpu-combine-tensor-select-and-if", - # ttgir to llvm passes "tritongpu-allocate-warp-groups", + "convert-scf-to-cf", "allocate-shared-memory", + "triton-tensor-memory-allocation", "tritongpu-global-scratch-memory-allocation", - "tritongpu-optimize-accumulator-init", - "tritongpu-coalesce-async-copy", + # TODO: register the commented out passes + # "convert-triton-gpu-to-llvm", + "canonicalize", + "cse", + # "convert-nv-gpu-to-llvm", + # "convert-warp-specialize-to-llvm", + "reconcile-unrealized-casts", + "canonicalize", + "cse", + "symbol-dce", + "enable-line-info", ], ",", ) + return "triton_ext.module(builtin.module($(all_passes)))" end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate diff --git a/src/Ops.jl b/src/Ops.jl index 1e9f404369..1ce18701ba 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3,7 +3,7 @@ # Julia and Reactant semantics should be considered on the higher abstractions that use these module Ops using ..MLIR: MLIR -using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla +using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext using ..Reactant: Reactant, TracedRArray, @@ -1749,7 +1749,6 @@ function _extract_function( code::String; func_name::String="main", func_op_kind::String="func.func", - nested_module::Bool=false, location::MLIR.IR.Location=MLIR.IR.Location(), ) module_suffix = string(hash(code); base=16) @@ -1757,24 +1756,45 @@ function _extract_function( mod_name = func_name * "_module_" * module_suffix symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - if nested_module + use_ttext_module = split(func_op_kind, ".")[1] == "tt" + + if use_ttext_module + tt_mod_name = func_name * "_tt_module_" * module_suffix + tt_region = MLIR.IR.Region() + tt_block = MLIR.IR.Block() + push!(tt_region, tt_block) + triton_mod_op = triton_ext.module_(; + location, bodyRegion=tt_region, sym_name=tt_mod_name + ) + MLIR.IR.rmfromparent!(triton_mod_op) + push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module + region = MLIR.IR.Region() push!(region, MLIR.IR.Block()) moduleop = MLIR.Dialects.builtin.module_(; location, bodyRegion=region, sym_name=mod_name ) MLIR.IR.rmfromparent!(moduleop) - push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module + push!(tt_block, moduleop) # insert into triton module top_level_block = MLIR.IR.Block( MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false ) fn = nothing + + symref = MLIR.IR.SymbolRefAttribute( + tt_mod_name, + MLIR.IR.Attribute[ + MLIR.IR.FlatSymbolRefAttribute(mod_name), + MLIR.IR.FlatSymbolRefAttribute(name_to_call), + ], + ) else current_module = MLIR.IR.mmodule() moduleop = MLIR.IR.Operation(current_module) top_level_block = MLIR.IR.body(current_module) fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call) + symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call) end if isnothing(fn) @@ -1795,12 +1815,14 @@ function _extract_function( ) @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" - # Set function private - MLIR.IR.attr!( - op, - MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), - MLIR.IR.Attribute("private"), - ) + if !use_ttext_module + # Set function private + MLIR.IR.attr!( + op, + MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), + MLIR.IR.Attribute("private"), + ) + end # Change function name MLIR.IR.attr!(op, 
symbol_attr_name, MLIR.IR.Attribute(name_to_call)) @@ -1815,7 +1837,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return fn, name_to_call, mod_name + return fn, symref end function triton_call( @@ -1829,19 +1851,15 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - _, name_to_call, mod_name = _extract_function( - mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location - ) + _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location) - enzymexla.triton_call( + triton_ext.call( grid_x.mlir_data, grid_y.mlir_data, grid_z.mlir_data, shmem.mlir_data, [Reactant.TracedUtils.get_mlir_data(a) for a in args]; - fn=MLIR.IR.SymbolRefAttribute( - mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)] - ), + fn=symref, result_0=MLIR.IR.Type[], location, ) @@ -1879,9 +1897,7 @@ julia> Reactant.@jit( func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - fn, name_to_call, _ = _extract_function( - code; func_name, func_op_kind="func.func", location - ) + fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location) ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) @@ -1898,7 +1914,7 @@ julia> Reactant.@jit( call = MLIR.Dialects.func.call( operands; result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], - callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + callee=symref, location, ) From 97c952b1ae3fac8b02589db1880fe98cabd551b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Sep 2025 16:10:43 -0500 Subject: [PATCH 14/28] fix: kind of working --- deps/ReactantExtra/make-bindings.jl | 2 +- src/Compiler.jl | 70 ++++++++++++++++++----------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index ebdb7cd9b0..9e4295e9cb 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -42,7 +42,7 @@ for file in [ "MPI.jl", "MemRef.jl", "SparseTensor.jl", - "TritonExt.jl" + "TritonExt.jl", ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/src/Compiler.jl b/src/Compiler.jl index bfc3382f3c..f36d7bb801 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -702,6 +702,7 @@ function optimization_passes( lower_comms::Bool=true, max_constant_threshold::Int=1024, backend::String="gpu", + enable_triton_passes::Bool=false, ) transform_passes_list = [ "patterns=compare_op_canon<16>", @@ -1300,7 +1301,7 @@ function optimization_passes( push!(passes, "remove-duplicate-func-def") end push!(passes, func_passes) - if backend == "cuda" + if enable_triton_passes && backend == "cuda" push!(passes, triton_optimization_passes()) end return join(passes, ',') @@ -1375,12 +1376,11 @@ function triton_optimization_passes() "allocate-shared-memory", "triton-tensor-memory-allocation", "tritongpu-global-scratch-memory-allocation", - # TODO: register the commented out passes - # "convert-triton-gpu-to-llvm", + "convert-triton-gpu-to-llvm", "canonicalize", "cse", - # "convert-nv-gpu-to-llvm", - # "convert-warp-specialize-to-llvm", + "convert-nv-gpu-to-llvm", + "convert-warp-specialize-to-llvm", "reconcile-unrealized-casts", "canonicalize", "cse", @@ -1781,10 +1781,28 @@ function compile_mlir!( end opt_passes = optimization_passes( - compile_options; sroa=true, recognize_comms, lower_comms, backend + 
compile_options; + sroa=true, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=false, ) opt_passes2 = optimization_passes( - compile_options; sroa=false, recognize_comms, lower_comms, backend + compile_options; + sroa=false, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=false, + ) + opt_passes3 = optimization_passes( + compile_options; + sroa=false, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=true, ) raise_passes = if raise isa String @@ -1799,7 +1817,7 @@ function compile_mlir!( opt_passes2 if DUS_TO_CONCAT[] - opt_passes3 = optimization_passes( + opt_passes_dus_to_concat = optimization_passes( compile_options; sroa=false, dus_to_concat=true, @@ -1807,7 +1825,7 @@ function compile_mlir!( lower_comms, backend, ) - result = result * "," * opt_passes3 + result = result * "," * opt_passes_dus_to_concat end result else @@ -1838,12 +1856,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, lower_enzymexla_linalg_pass, jit, ] @@ -1854,12 +1872,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, kern, raise_passes, lower_enzymexla_linalg_pass, @@ -1883,12 +1901,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, ] end, ',', @@ -1908,12 +1926,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, ] else [ @@ -1922,12 +1940,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, kern, raise_passes, ] @@ -1949,12 +1967,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, kern, ] end, @@ -1972,12 +1990,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes3, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, ], ',', ), @@ -2014,7 +2032,7 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, lower_enzymexla_linalg_pass, jit, ] @@ -2027,7 +2045,7 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes3, kern, raise_passes, lower_enzymexla_linalg_pass, @@ -2238,7 +2256,7 @@ function compile_mlir!( run_pass_pipeline!( mod, join( - [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2], + [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3], ",", ), "mid_pad_opts", From b84d51911df51574c1e9970f31d57e4231e02c2a Mon Sep 17 00:00:00 2001 From: Avik 
Pal Date: Mon, 29 Sep 2025 20:41:07 -0500 Subject: [PATCH 15/28] fix: new API --- ext/ReactantPythonCallExt/pycall.jl | 20 +++++++++++----- src/Compiler.jl | 36 ++++++++++++++--------------- src/Ops.jl | 8 +++++-- src/mlir/Dialects/TritonExt.jl | 6 +++-- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 40026af81f..0788daef7e 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -47,9 +47,8 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end -# TODO: support using metaparams here -normalize_grid(grid::Integer) = normalize_grid((grid,)) -function normalize_grid(grid::Dims{N}) where {N} +normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,)) +function normalize_grid_and_blocks(grid::Dims{N}) where {N} @assert N <= 3 @assert all(grid .> 0) return (grid..., ntuple(_ -> 1, 3 - N)...) @@ -62,11 +61,18 @@ signature_string(x) = error("Unsupported argument type: $(typeof(x))") # TODO: better name for hints? function overlayed_pycall_with_triton( - kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing + kernel::Py, + args...; + grid, + blocks, + num_warps::Integer=1, + num_stages::Integer=3, + hints=nothing, ) triton = tritonptr[] - grid = normalize_grid(grid) + grid = normalize_grid_and_blocks(grid) + blocks = normalize_grid_and_blocks(blocks) mapped = map(signature_string, args) signature = first.(mapped) @@ -121,7 +127,9 @@ function overlayed_pycall_with_triton( grid_x=@opcall(constant(grid[1])), grid_y=@opcall(constant(grid[2])), grid_z=@opcall(constant(grid[3])), - shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))), + block_x=@opcall(constant(blocks[1])), + block_y=@opcall(constant(blocks[2])), + block_z=@opcall(constant(blocks[3])), ) return nothing diff --git a/src/Compiler.jl b/src/Compiler.jl index f36d7bb801..5f5fca2a5b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1796,7 +1796,7 @@ function compile_mlir!( backend, enable_triton_passes=false, ) - opt_passes3 = optimization_passes( + opt_passes_with_triton = optimization_passes( compile_options; sroa=false, recognize_comms, @@ -1856,12 +1856,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, lower_enzymexla_linalg_pass, jit, ] @@ -1872,12 +1872,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, kern, raise_passes, lower_enzymexla_linalg_pass, @@ -1901,12 +1901,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, ] end, ',', @@ -1926,12 +1926,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, ] else [ @@ -1940,12 +1940,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, 
enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, kern, raise_passes, ] @@ -1967,12 +1967,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, kern, ] end, @@ -1990,12 +1990,12 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes3, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes2, ], ',', ), @@ -2032,7 +2032,7 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes_with_triton, lower_enzymexla_linalg_pass, jit, ] @@ -2045,7 +2045,7 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes3, + opt_passes_with_triton, kern, raise_passes, lower_enzymexla_linalg_pass, @@ -2256,7 +2256,7 @@ function compile_mlir!( run_pass_pipeline!( mod, join( - [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3], + [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2], ",", ), "mid_pad_opts", diff --git a/src/Ops.jl b/src/Ops.jl index 1ce18701ba..3f81e228f2 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1847,7 +1847,9 @@ function triton_call( grid_x::TracedRNumber{<:Integer}, grid_y::TracedRNumber{<:Integer}, grid_z::TracedRNumber{<:Integer}, - shmem::TracedRNumber{<:Integer}, + block_x::TracedRNumber{<:Integer}, + block_y::TracedRNumber{<:Integer}, + block_z::TracedRNumber{<:Integer}, location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) @@ -1857,7 +1859,9 @@ function triton_call( grid_x.mlir_data, grid_y.mlir_data, grid_z.mlir_data, - shmem.mlir_data, + block_x.mlir_data, + block_y.mlir_data, + block_z.mlir_data, [Reactant.TracedUtils.get_mlir_data(a) for a in args]; fn=symref, result_0=MLIR.IR.Type[], diff --git a/src/mlir/Dialects/TritonExt.jl b/src/mlir/Dialects/TritonExt.jl index f59822b909..bb79bade44 100644 --- a/src/mlir/Dialects/TritonExt.jl +++ b/src/mlir/Dialects/TritonExt.jl @@ -17,7 +17,9 @@ function call( gridx::Value, gridy::Value, gridz::Value, - shmem::Value, + blockx::Value, + blocky::Value, + blockz::Value, inputs::Vector{Value}; result_0::Vector{IR.Type}, fn, @@ -31,7 +33,7 @@ function call( location=Location(), ) op_ty_results = IR.Type[result_0...,] - operands = Value[gridx, gridy, gridz, shmem, inputs...] + operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...] 
owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("fn", fn),] From 732e6c32190bef22dcbd91ca4dc9daad6436d43f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Sep 2025 21:48:39 -0500 Subject: [PATCH 16/28] feat: return values --- ext/ReactantPythonCallExt/pycall.jl | 4 +--- src/Ops.jl | 37 +++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 0788daef7e..f99a30b8ac 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -120,7 +120,7 @@ function overlayed_pycall_with_triton( # we are compiling here + lowering again inside enzymejax ccinfo = triton.compile(src; target=target, options=options.__dict__) - @opcall triton_call( + return @opcall triton_call( pyconvert(String, ccinfo.asm["source"]), filter(x -> x isa Reactant.TracedType, args)...; func_name=pyconvert(String, ccinfo.metadata.name), @@ -131,6 +131,4 @@ function overlayed_pycall_with_triton( block_y=@opcall(constant(blocks[2])), block_z=@opcall(constant(blocks[3])), ) - - return nothing end diff --git a/src/Ops.jl b/src/Ops.jl index 3f81e228f2..135a466b9e 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1851,11 +1851,28 @@ function triton_call( block_y::TracedRNumber{<:Integer}, block_z::TracedRNumber{<:Integer}, location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), - # TODO: other kwargs ) _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location) - triton_ext.call( + result_types = MLIR.IR.Type[] + output_operand_aliases = MLIR.IR.Attribute[] + output_to_arg = Int[] + for (i, arg) in enumerate(args) + if arg isa TracedRArray + push!(result_types, mlir_type(typeof(arg), size(arg))) + push!( + output_operand_aliases, + MLIR.IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL + ), + ), + ) + push!(output_to_arg, i) + end + end + + results = triton_ext.call( grid_x.mlir_data, grid_y.mlir_data, grid_z.mlir_data, @@ -1864,11 +1881,23 @@ function triton_call( block_z.mlir_data, [Reactant.TracedUtils.get_mlir_data(a) for a in args]; fn=symref, - result_0=MLIR.IR.Type[], + result_0=result_types, location, + output_operand_aliases, ) - return nothing + array_results = () + for i in 1:MLIR.IR.nresults(results) + arg = args[output_to_arg[i]] + array_results = ( + array_results..., + Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}( + (), MLIR.IR.result(results, i), size(arg) + ), + ) + end + length(array_results) == 1 && return array_results[1] + return array_results end """ From 427b54aa50152908e4b03b335598a866f7795ffa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Oct 2025 15:41:53 -0500 Subject: [PATCH 17/28] feat: lowering triton now works --- src/CompileOptions.jl | 1 + src/Compiler.jl | 44 ++++++++++++++++++++++++++++++++++++++----- src/Ops.jl | 2 +- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..f70e63f0be 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,7 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :no_triton, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index 5f5fca2a5b..045ed12b53 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1851,12 +1851,14 @@ function compile_mlir!( [ "mark-func-memory-effects", opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", 
opt_passes2, enzyme_pass, - opt_passes_with_triton, + opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", @@ -1878,6 +1880,7 @@ function compile_mlir!( "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, @@ -1888,6 +1891,31 @@ function compile_mlir!( ), "all", ) + elseif compile_options.optimization_passes === :no_triton + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + ["mark-func-memory-effects", opt_passes] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + legalize_chlo_to_stablehlo..., + opt_passes2, + ] + end, + ',', + ), + "before_kernel", + ) elseif compile_options.optimization_passes === :before_kernel run_pass_pipeline!( mod, @@ -1920,13 +1948,14 @@ function compile_mlir!( if compile_options.raise_first [ "mark-func-memory-effects", - opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes_with_triton, + opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", @@ -1946,6 +1975,7 @@ function compile_mlir!( "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, raise_passes, ] @@ -1973,6 +2003,7 @@ function compile_mlir!( "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, ] end, @@ -2046,6 +2077,7 @@ function compile_mlir!( "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes_with_triton, + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, @@ -2063,7 +2095,8 @@ function compile_mlir!( if compile_options.raise_first [ "mark-func-memory-effects", - opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -2078,9 +2111,10 @@ function compile_mlir!( "mark-func-memory-effects", opt_passes, "enzyme-batch", - opt_passes2, + opt_passes_with_triton, enzyme_pass, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math", + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, diff --git a/src/Ops.jl b/src/Ops.jl index 135a466b9e..dc21a2609d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1864,7 +1864,7 @@ function triton_call( output_operand_aliases, MLIR.IR.Attribute( MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL + MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL ), ), ) From a07bb1319a3a3295fab9465a5d00fb4d1a300766 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Oct 2025 17:28:19 -0500 Subject: [PATCH 18/28] feat: triton working end to end --- src/Ops.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index dc21a2609d..9a7eea4ec3 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1889,12 +1889,11 @@ function triton_call( array_results = () for i in 1:MLIR.IR.nresults(results) arg = args[output_to_arg[i]] - array_results = ( - array_results..., - Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}( - (), MLIR.IR.result(results, i), size(arg) - ), + res = Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}( + (), MLIR.IR.result(results, i), size(arg) ) + copyto!(arg, res) + array_results = (array_results..., res) end length(array_results) == 1 && return array_results[1] return array_results From 
ed521e2c1a90610d8b64c3b5086629e9e406931f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 08:52:15 -0500 Subject: [PATCH 19/28] chore: bump commit --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 7a06caeac6..763801a229 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "6137ac98e710adf6f4e953bf441db4e25b2db40f" +ENZYMEXLA_COMMIT = "f2072aa2031eb6a1d5d1972d3a95340fb67c9480" ENZYMEXLA_SHA256 = "" From bea3b16e6731008f5017075c237e566e84826965 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 11:16:56 -0500 Subject: [PATCH 20/28] fix: extra export + naming --- deps/ReactantExtra/BUILD | 3 --- src/Compiler.jl | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 5d8ba5eda4..0d7ef130c1 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -979,9 +979,6 @@ cc_library( "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", "-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion", "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", - "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", - "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", - "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads", "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", diff --git a/src/Compiler.jl b/src/Compiler.jl index 045ed12b53..4ccc96e4cd 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1914,7 +1914,7 @@ function compile_mlir!( end, ',', ), - "before_kernel", + "no_triton", ) elseif compile_options.optimization_passes === :before_kernel run_pass_pipeline!( From 4a94a22a287bd1ee4f692d5f963d474436b85a49 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 12:30:57 -0500 Subject: [PATCH 21/28] feat: allow grid/blocks via a function [skip ci] --- ext/ReactantPythonCallExt/pycall.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index f99a30b8ac..3072c30c81 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -47,8 +47,14 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end -normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,)) -function normalize_grid_and_blocks(grid::Dims{N}) where {N} +function normalize_grid_and_blocks(grid_fn, metadata) + return normalize_grid_and_blocks(grid_fn(metadata), metadata) +end + +function normalize_grid_and_blocks(grid::Integer, metadata) + return normalize_grid_and_blocks((grid,), metadata) +end +function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N} @assert N <= 3 @assert all(grid .> 0) return (grid..., ntuple(_ -> 1, 3 - N)...) @@ -71,9 +77,6 @@ function overlayed_pycall_with_triton( ) triton = tritonptr[] - grid = normalize_grid_and_blocks(grid) - blocks = normalize_grid_and_blocks(blocks) - mapped = map(signature_string, args) signature = first.(mapped) # TODO: are hints actually correctly set? 
@@ -120,6 +123,9 @@ function overlayed_pycall_with_triton( # we are compiling here + lowering again inside enzymejax ccinfo = triton.compile(src; target=target, options=options.__dict__) + grid = normalize_grid_and_blocks(grid, ccinfo.metadata) + blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata) + return @opcall triton_call( pyconvert(String, ccinfo.asm["source"]), filter(x -> x isa Reactant.TracedType, args)...; From b017d5b0c351a3239a6c7fbeb5b081c4c171b40c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 15:53:51 -0500 Subject: [PATCH 22/28] feat: use new device properties [skip ci] --- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantPythonCallExt/pycall.jl | 35 ++++++++++++++++++++--------- src/Compiler.jl | 25 ++++++++++++++++----- src/Ops.jl | 13 ++++++++--- 4 files changed, 55 insertions(+), 20 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 763801a229..159e549d49 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "f2072aa2031eb6a1d5d1972d3a95340fb67c9480" +ENZYMEXLA_COMMIT = "8221b6147f497592205e6f558b1609e2964f3330" ENZYMEXLA_SHA256 = "" diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 3072c30c81..19cd1634cc 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -47,14 +47,16 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end -function normalize_grid_and_blocks(grid_fn, metadata) - return normalize_grid_and_blocks(grid_fn(metadata), metadata) +function normalize_grid_and_blocks(grid_fn, metadata, device_properties) + return normalize_grid_and_blocks( + grid_fn(metadata, device_properties), metadata, device_properties + ) end -function normalize_grid_and_blocks(grid::Integer, metadata) - return normalize_grid_and_blocks((grid,), metadata) +function normalize_grid_and_blocks(grid::Integer, metadata, device_properties) + return normalize_grid_and_blocks((grid,), metadata, device_properties) end -function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N} +function normalize_grid_and_blocks(grid::Dims{N}, metadata, device_properties) where {N} @assert N <= 3 @assert all(grid .> 0) return (grid..., ntuple(_ -> 1, 3 - N)...) 
@@ -71,8 +73,9 @@ function overlayed_pycall_with_triton( args...; grid, blocks, - num_warps::Integer=1, + num_warps::Integer=4, num_stages::Integer=3, + num_ctas::Integer=1, hints=nothing, ) triton = tritonptr[] @@ -105,16 +108,23 @@ function overlayed_pycall_with_triton( fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs ) + # TODO: pass the device/client here from `compile` + client = Reactant.XLA.default_backend() + @assert Reactant.XLA.platform_name(client) == "cuda" + device = Reactant.XLA.default_device(client) + device_properties = Reactant.XLA.device_properties(device) + target = triton.backends.compiler.GPUTarget( - "cuda", - parse(Int, Reactant.Compiler.cubinChip[][4:end]), - Reactant.Compiler.cuWarpSize[], + Reactant.XLA.platform_name(client), + parse(Int, "$(device_properties.major)$(device_properties.minor)"), + device_properties.warp_size, ) backend = triton.compiler.make_backend(target) options = backend.parse_options( pydict( "num_warps" => num_warps, "num_stages" => num_stages, + "num_ctas" => num_ctas, "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)), ), ) @@ -123,8 +133,8 @@ function overlayed_pycall_with_triton( # we are compiling here + lowering again inside enzymejax ccinfo = triton.compile(src; target=target, options=options.__dict__) - grid = normalize_grid_and_blocks(grid, ccinfo.metadata) - blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata) + grid = normalize_grid_and_blocks(grid, ccinfo.metadata, device_properties) + blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata, device_properties) return @opcall triton_call( pyconvert(String, ccinfo.asm["source"]), @@ -136,5 +146,8 @@ function overlayed_pycall_with_triton( block_x=@opcall(constant(blocks[1])), block_y=@opcall(constant(blocks[2])), block_z=@opcall(constant(blocks[3])), + # The following are written to module attributes and restored later on + num_ctas, + num_warps, ) end diff --git a/src/Compiler.jl b/src/Compiler.jl index 4ccc96e4cd..2e21837230 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -703,6 +703,7 @@ function optimization_passes( max_constant_threshold::Int=1024, backend::String="gpu", enable_triton_passes::Bool=false, + device_properties::Union{Nothing,XLA.DeviceProperties}=nothing, ) transform_passes_list = [ "patterns=compare_op_canon<16>", @@ -1302,14 +1303,20 @@ function optimization_passes( end push!(passes, func_passes) if enable_triton_passes && backend == "cuda" - push!(passes, triton_optimization_passes()) + push!(passes, triton_optimization_passes(device_properties)) end return join(passes, ',') end # https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc # To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes -function triton_optimization_passes() +function triton_optimization_passes(device_properties) + @assert device_properties !== nothing "Device properties must be provided to run \ + triton passes. This might happen if you are \ + compiling a triton kernel for non-cuda backend." 
+ major_version = device_properties.major + minor_version = device_properties.minor + all_passes = join( [ "canonicalize", @@ -1320,7 +1327,9 @@ function triton_optimization_passes() "cse", "symbol-dce", "triton-loop-unroll", - "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", + "preserve-triton-warps-ctas{save=true restore=false}", + "convert-triton-to-tritongpu{target=cuda:$(major_version)$(minor_version)}", + "preserve-triton-warps-ctas{save=false restore=true}", "tritongpu-coalesce", "tritongpu-F32DotTC", "triton-nvidia-gpu-plan-cta", @@ -1740,6 +1749,9 @@ function compile_mlir!( toolkit = XLA.CUDA_DATA_DIR[] + default_device = XLA.default_device(client) + device_properties = XLA.device_properties(default_device) + if backend == "cpu" || backend == "tpu" kern = "lower-kernel{backend=cpu},canonicalize" if backend == "tpu" @@ -1754,9 +1766,7 @@ function compile_mlir!( "lower-kernel,canonicalize" end - device_properties = XLA.device_properties(XLA.default_device(client)) cubinChip = "sm_$(device_properties.major)$(device_properties.minor)" - if DEBUG_KERNEL[] curesulthandler = dlsym( Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" @@ -1787,6 +1797,7 @@ function compile_mlir!( lower_comms, backend, enable_triton_passes=false, + device_properties, ) opt_passes2 = optimization_passes( compile_options; @@ -1795,6 +1806,7 @@ function compile_mlir!( lower_comms, backend, enable_triton_passes=false, + device_properties, ) opt_passes_with_triton = optimization_passes( compile_options; @@ -1803,6 +1815,7 @@ function compile_mlir!( lower_comms, backend, enable_triton_passes=true, + device_properties, ) raise_passes = if raise isa String @@ -1824,6 +1837,7 @@ function compile_mlir!( recognize_comms, lower_comms, backend, + device_properties, ) result = result * "," * opt_passes_dus_to_concat end @@ -2148,6 +2162,7 @@ function compile_mlir!( recognize_comms, lower_comms, backend, + device_properties, ), "post_op_transpose_reshape", ) diff --git a/src/Ops.jl b/src/Ops.jl index 9a7eea4ec3..7e778041c0 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1837,7 +1837,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return fn, symref + return fn, symref, moduleop end function triton_call( @@ -1850,9 +1850,16 @@ function triton_call( block_x::TracedRNumber{<:Integer}, block_y::TracedRNumber{<:Integer}, block_z::TracedRNumber{<:Integer}, + num_ctas::Integer=1, + num_warps::Integer=4, location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), ) - _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location) + _, symref, modop = _extract_function( + mlir_code; func_name, func_op_kind="tt.func", location + ) + + MLIR.IR.attr!(modop, "ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps))) + MLIR.IR.attr!(modop, "ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas))) result_types = MLIR.IR.Type[] output_operand_aliases = MLIR.IR.Attribute[] @@ -1929,7 +1936,7 @@ function triton_call( func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location) + fn, symref, _ = _extract_function(code; func_name, func_op_kind="func.func", location) ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) From 6af5f94968900526668b820f4262555bbd34e919 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 19:45:20 -0500 Subject:
[PATCH 23/28] feat: correctly set strides + get n_regs --- deps/ReactantExtra/API.cpp | 24 ++++++++++++ deps/ReactantExtra/BUILD | 1 + ext/ReactantPythonCallExt/pycall.jl | 58 +++++++++++++++++++++++------ src/Reactant.jl | 29 +++++++++++++++ 4 files changed, 100 insertions(+), 12 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f8ad3715c0..51e038e816 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -812,6 +812,26 @@ REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops, jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor; } +REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + const char *binary, const char *fnname, int32_t *regs, int32_t *spills, + int32_t *maxThreads) { + CUfunction fun; + CUmodule mod; + + ReactantHandleCuResult(cuModuleLoadData(&mod, binary)); + ReactantHandleCuResult(cuModuleGetFunction(&fun, mod, fnname)); + + ReactantHandleCuResult( + cuFuncGetAttribute(regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + ReactantHandleCuResult( + cuFuncGetAttribute(spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + *spills /= 4; + ReactantHandleCuResult(cuFuncGetAttribute( + maxThreads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); + + return; +} + #else REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; } @@ -827,6 +847,10 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; } REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops, int32_t device_id) {} +REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + const char *binary, const char *fnname, int32_t *regs, int32_t *spills, + int32_t *maxThreads) {} + #endif REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) { diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 0d7ef130c1..f44e491c67 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -984,6 +984,7 @@ cc_library( "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads", "-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties", + "-Wl,-exported_symbol,_ReactantCudaGetRegsSpillsMaxThreadsFromBinary", "-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId", "-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId", diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 19cd1634cc..3e8e5f5cb9 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -47,16 +47,25 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...) return length(res) == 0 ? nothing : (length(res) == 1 ? 
res[1] : res) end -function normalize_grid_and_blocks(grid_fn, metadata, device_properties) - return normalize_grid_and_blocks( - grid_fn(metadata, device_properties), metadata, device_properties - ) +struct TritonMetadata{CK,MD,DP} + compiled_kernel::CK + metadata::MD + device_properties::DP + num_warps::Int + num_stages::Int + num_ctas::Int + num_regs::Int + num_spills::Int + max_num_threads::Int end -function normalize_grid_and_blocks(grid::Integer, metadata, device_properties) - return normalize_grid_and_blocks((grid,), metadata, device_properties) +function normalize_grid_and_blocks(grid_fn, metadata) + return normalize_grid_and_blocks(grid_fn(metadata), metadata) +end +function normalize_grid_and_blocks(grid::Integer, metadata) + return normalize_grid_and_blocks((grid,), metadata) end -function normalize_grid_and_blocks(grid::Dims{N}, metadata, device_properties) where {N} +function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N} @assert N <= 3 @assert all(grid .> 0) return (grid..., ntuple(_ -> 1, 3 - N)...) @@ -131,15 +140,40 @@ function overlayed_pycall_with_triton( # Currently we are doing a double compilation here. can we do better? # we are compiling here + lowering again inside enzymejax - ccinfo = triton.compile(src; target=target, options=options.__dict__) + compiled_kernel = triton.compile(src; target=target, options=options.__dict__) + + cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"]) + fname = pyconvert(String, compiled_kernel.metadata.name) + n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}() + GC.@preserve cubin fname n_regs n_spills n_max_threads begin + @ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + cubin::Ptr{Cvoid}, + fname::Cstring, + n_regs::Ptr{Int32}, + n_spills::Ptr{Int32}, + n_max_threads::Ptr{Int32}, + )::Cvoid + end + + metadata = TritonMetadata( + compiled_kernel, + compiled_kernel.metadata, + device_properties, + num_warps, + num_stages, + num_ctas, + Int(n_regs[]), + Int(n_spills[]), + Int(n_max_threads[]), + ) - grid = normalize_grid_and_blocks(grid, ccinfo.metadata, device_properties) - blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata, device_properties) + grid = normalize_grid_and_blocks(grid, metadata) + blocks = normalize_grid_and_blocks(blocks, metadata) return @opcall triton_call( - pyconvert(String, ccinfo.asm["source"]), + pyconvert(String, compiled_kernel.asm["source"]), filter(x -> x isa Reactant.TracedType, args)...; - func_name=pyconvert(String, ccinfo.metadata.name), + func_name=fname, grid_x=@opcall(constant(grid[1])), grid_y=@opcall(constant(grid[2])), grid_z=@opcall(constant(grid[3])), diff --git a/src/Reactant.jl b/src/Reactant.jl index 7c31f1a8c5..3922f0761f 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -236,6 +236,35 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +""" + rowmajor_strides(x::AbstractArray) + +Returns the strides of the array `x` assuming that the array is stored in row-major order. +""" +rowmajor_strides(x::AbstractArray) = rowmajor_strides(size(x)) +function rowmajor_strides(sz::NTuple{N,Int}) where {N} + strides = ntuple(_ -> 1, N) + for i in (N - 1):-1:1 + strides = Base.setindex(strides, strides[i + 1] * sz[i + 1], i) + end + return strides +end + +""" + rowmajor_stride(x::AbstractArray, i::Integer) + +Returns the stride of the array `x` at dimension `i` assuming that the array is stored in +row-major order. 
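+
+For example (a minimal illustration; the size tuple `(2, 3, 4)` below is arbitrary and
+any size tuple behaves the same way):
+
+```julia
+julia> Reactant.rowmajor_stride((2, 3, 4), 1)
+12
+
+julia> Reactant.rowmajor_stride((2, 3, 4), 3)
+1
+```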
+""" +rowmajor_stride(x::AbstractArray, i::Integer) = rowmajor_stride(size(x), i) +function rowmajor_stride(sz::NTuple{N,Int}, i::Integer) where {N} + s = 1 + for j in (i + 1):N + s *= sz[j] + end + return s +end + export StackedBatchDuplicated, StackedBatchDuplicatedNoNeed const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} From 3ef41eca39f47f83ae804cdd449712724969c2c5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 21:20:11 -0500 Subject: [PATCH 24/28] test: add some triton tests --- test/integration/triton/low_memory_dropout.jl | 38 ++++++++++++ test/integration/triton/low_memory_dropout.py | 29 +++++++++ test/integration/triton/softmax.jl | 60 +++++++++++++++++++ test/integration/triton/softmax.py | 38 ++++++++++++ test/integration/triton/vector_add.jl | 22 +++++++ test/integration/triton/vector_add.py | 31 ++++++++++ test/runtests.jl | 18 ++++++ 7 files changed, 236 insertions(+) create mode 100644 test/integration/triton/low_memory_dropout.jl create mode 100644 test/integration/triton/low_memory_dropout.py create mode 100644 test/integration/triton/softmax.jl create mode 100644 test/integration/triton/softmax.py create mode 100644 test/integration/triton/vector_add.jl create mode 100644 test/integration/triton/vector_add.py diff --git a/test/integration/triton/low_memory_dropout.jl b/test/integration/triton/low_memory_dropout.jl new file mode 100644 index 0000000000..a03f5fa030 --- /dev/null +++ b/test/integration/triton/low_memory_dropout.jl @@ -0,0 +1,38 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +low_memory_dropout_kernel = pyimport("low_memory_dropout").seeded_dropout_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function seeded_dropout(x::AbstractVector{T}, p::Number, seed) where {T} + output = similar(x) + mask = similar(x, Bool) + low_memory_dropout_kernel( + x, + output, + mask, + length(x), + p, + seed, + 1024; + grid=(cld(length(x), 1024),), + blocks=(1024,), + ) + return output, mask +end + +function apply_dropout(x::AbstractVector{T}, mask::AbstractVector, p::Number) where {T} + return x .* mask ./ (1 - p) +end + +@testset "low_memory_dropout" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2056)) + + out, mask = @jit seeded_dropout(x_ra, 0.25f0, ConcreteRNumber(123)) + + @test @jit(apply_dropout(x_ra, mask, 0.25f0)) ≈ out + end +end diff --git a/test/integration/triton/low_memory_dropout.py b/test/integration/triton/low_memory_dropout.py new file mode 100644 index 0000000000..ad32ac0014 --- /dev/null +++ b/test/integration/triton/low_memory_dropout.py @@ -0,0 +1,29 @@ +import triton +import triton.language as tl + + +@triton.jit +def seeded_dropout_kernel( + x_ptr, + output_ptr, + mask_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + mask_out = tl.where(x_keep, 1.0, 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + tl.store(mask_ptr + offsets, mask_out, mask=mask) diff --git a/test/integration/triton/softmax.jl b/test/integration/triton/softmax.jl new file mode 100644 index 0000000000..57ca96e7db --- /dev/null +++ 
b/test/integration/triton/softmax.jl @@ -0,0 +1,60 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +softmax_kernel = pyimport("softmax").softmax_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function softmax_naive(x::AbstractMatrix{T}) where {T} + x_max = maximum(x; dims=1) + z = x .- x_max + num = exp.(z) + denom = sum(num; dims=1) + return num ./ denom +end + +function softmax_triton(x::AbstractMatrix{T}) where {T} + x_transposed = permutedims(x, (2, 1)) # match python array layout + out = similar(x_transposed) + n_rows, n_cols = size(x_transposed) + + function grid_fn(metadata) + occupancy = ( + metadata.device_properties.regs_per_block ÷ + (metadata.num_regs * metadata.device_properties.warp_size * metadata.num_warps) + ) + + num_programs = min( + metadata.device_properties.multi_processor_count * min( + occupancy, + metadata.device_properties.shared_mem_per_block ÷ metadata.metadata.shared, + ), + n_rows, + ) + return num_programs + end + + softmax_kernel( + out, + x_transposed, + Reactant.rowmajor_stride(x_transposed, 1), + Reactant.rowmajor_stride(out, 1), + n_rows, + n_cols, + BLOCK_SIZE, + num_stages; + grid=grid_fn, + blocks=(BLOCK_SIZE,), + ) + + return permutedims(out, (2, 1)) +end + +@testset "softmax" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 132, 2056)) + + @test @jit(softmax_triton(x_ra)) ≈ @jit(softmax_naive(x_ra)) + end +end diff --git a/test/integration/triton/softmax.py b/test/integration/triton/softmax.py new file mode 100644 index 0000000000..0a80c43275 --- /dev/null +++ b/test/integration/triton/softmax.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl + + +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr, +): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) diff --git a/test/integration/triton/vector_add.jl b/test/integration/triton/vector_add.jl new file mode 100644 index 0000000000..548d192cc2 --- /dev/null +++ b/test/integration/triton/vector_add.jl @@ -0,0 +1,22 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +add_kernel = pyimport("vector_add").add_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function vector_add_triton(x::AbstractVector{T}, 
y::AbstractVector{T}) where {T} + out = similar(x) + add_kernel(x, y, out, length(x), 1024; grid=(cld(length(x), 1024),), blocks=(1024,)) + return out +end + +@testset "vector_add" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2096)) + y_ra = Reactant.to_rarray(rand(Float32, 2096)) + + @test @jit(vector_add_triton(x_ra, y_ra)) ≈ @jit(x_ra .+ y_ra) + end +end diff --git a/test/integration/triton/vector_add.py b/test/integration/triton/vector_add.py new file mode 100644 index 0000000000..6b04d51a7d --- /dev/null +++ b/test/integration/triton/vector_add.py @@ -0,0 +1,31 @@ +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. 
+ tl.store(output_ptr + offsets, output, mask=mask) diff --git a/test/runtests.jl b/test/runtests.jl index f812deee5c..81cd5c3f99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,24 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) nranks = 2 run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) end + @testset "Triton" begin + @safetestset "vector_add" include("integration/triton/vector_add.jl") + @safetestset "softmax" include("integration/triton/softmax.jl") + # @safetestset "matmul" include("integration/triton/matmul.jl") + @safetestset "low_memory_dropout" include( + "integration/triton/low_memory_dropout.jl" + ) + # @safetestset "layer norm" include("integration/triton/layer_norm.jl") + # @safetestset "attention" include("integration/triton/attention.jl") + # @safetestset "libdevice" include("integration/triton/libdevice.jl") + # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl") + # @safetestset "persistant matmul" include( + # "integration/triton/persistant_matmul.jl" + # ) + # @safetestset "block scaled matmul" include( + # "integration/triton/block_scaled_matmul.jl" + # ) + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 1b84dd9c5cde3d1157c32d4fa211d80673e27aae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Oct 2025 22:11:39 -0500 Subject: [PATCH 25/28] test: layer_norm + libdevice --- test/integration/triton/layer_norm.jl | 67 +++++++++++++++++++++++++++ test/integration/triton/layer_norm.py | 52 +++++++++++++++++++++ test/integration/triton/libdevice.jl | 21 +++++++++ test/integration/triton/libdevice.py | 19 ++++++++ test/runtests.jl | 4 +- 5 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 test/integration/triton/layer_norm.jl create mode 100644 test/integration/triton/layer_norm.py create mode 100644 test/integration/triton/libdevice.jl create mode 100644 test/integration/triton/libdevice.py diff --git a/test/integration/triton/layer_norm.jl b/test/integration/triton/layer_norm.jl new file mode 100644 index 0000000000..40adf866f6 --- /dev/null +++ b/test/integration/triton/layer_norm.jl @@ -0,0 +1,67 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +layer_norm_kernel = pyimport("layer_norm").layer_norm_fwd_fused + +function layer_norm_triton( + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T} +) where {T} + x_transposed = permutedims(x, (2, 1)) # match python array layout + y = similar(x_transposed) + M, N = size(x_transposed) + mean = similar(x_transposed, Float32, M) + rstd = similar(x_transposed, Float32, M) + + max_fused_size = 65536 ÷ sizeof(T) + block_size = min(max_fused_size, nextpow(2, N)) + + if N > block_size + throw(ArgumentError("This layer norm doesn't support feature dim >= 64KB.")) + end + + num_warps = min(max(block_size ÷ 256, 1), 8) + + layer_norm_kernel( + x_transposed, + y, + weight, + bias, + mean, + rstd, + Reactant.rowmajor_stride(x_transposed, 1), + N, + 1.0f-5, + block_size; + num_warps=num_warps, + num_ctas=1, + grid=(M,), + blocks=(block_size,), + ) + + return permutedims(y, (2, 1)), mean, rstd +end + +function layer_norm_naive( + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T} +) where {T} + mean = sum(x; dims=1) ./ size(x, 1) + rstd = 1 ./ sqrt.(sum(abs2, x .- mean; dims=1) ./ size(x, 1) .+ 1e-5) + x_hat = (x .- mean) .* rstd + return x_hat .* weight .+ bias, vec(mean), vec(rstd) +end + +@testset "fused_layer_norm" 
begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 256, 2056)) + weight_ra = Reactant.to_rarray(rand(Float32, 256)) + bias_ra = Reactant.to_rarray(rand(Float32, 256)) + + y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra) + y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra) + + @test y_ra1 ≈ y_ra2 + @test mean_ra1 ≈ mean_ra2 + @test rstd_ra1 ≈ rstd_ra2 + end +end \ No newline at end of file diff --git a/test/integration/triton/layer_norm.py b/test/integration/triton/layer_norm.py new file mode 100644 index 0000000000..c50e715dcf --- /dev/null +++ b/test/integration/triton/layer_norm.py @@ -0,0 +1,52 @@ +import triton +import triton.language as tl + + +@triton.jit +def layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) diff --git a/test/integration/triton/libdevice.jl b/test/integration/triton/libdevice.jl new file mode 100644 index 0000000000..6376ecd674 --- /dev/null +++ b/test/integration/triton/libdevice.jl @@ -0,0 +1,21 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +asin_kernel = pyimport("libdevice").asin_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function asin_triton(x::AbstractVector{T}) where {T} + out = similar(x) + asin_kernel(x, out, length(x), 1024; grid=(cld(length(x), 1024),), blocks=(1024,)) + return out +end + +@testset "libdevice asin" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2096)) + + @test @jit(asin_triton(x_ra)) ≈ @jit(asin.(x_ra)) + end +end diff --git a/test/integration/triton/libdevice.py b/test/integration/triton/libdevice.py new file mode 100644 index 0000000000..ac9a199952 --- /dev/null +++ b/test/integration/triton/libdevice.py @@ -0,0 +1,19 @@ +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + 
tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) diff --git a/test/runtests.jl b/test/runtests.jl index 81cd5c3f99..bda0a84c4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,9 +66,9 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "low_memory_dropout" include( "integration/triton/low_memory_dropout.jl" ) - # @safetestset "layer norm" include("integration/triton/layer_norm.jl") + @safetestset "layer norm" include("integration/triton/layer_norm.jl") # @safetestset "attention" include("integration/triton/attention.jl") - # @safetestset "libdevice" include("integration/triton/libdevice.jl") + @safetestset "libdevice" include("integration/triton/libdevice.jl") # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl") # @safetestset "persistant matmul" include( # "integration/triton/persistant_matmul.jl" From e9623d0e229f241d76d7245e26937c9c22754d2e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 02:00:12 -0500 Subject: [PATCH 26/28] fix: partial fix to the blocks --- ext/ReactantPythonCallExt/pycall.jl | 22 +- src/Compiler.jl | 1 + test/integration/triton/layer_norm.jl | 26 +- test/integration/triton/layer_norm.py | 51 ++++ test/integration/triton/libdevice.jl | 2 +- test/integration/triton/low_memory_dropout.jl | 10 +- test/integration/triton/matmul.jl | 61 ++++ test/integration/triton/matmul.py | 264 ++++++++++++++++++ test/integration/triton/softmax.jl | 1 - test/integration/triton/vector_add.jl | 2 +- test/runtests.jl | 4 +- 11 files changed, 405 insertions(+), 39 deletions(-) create mode 100644 test/integration/triton/matmul.jl create mode 100644 test/integration/triton/matmul.py diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 3e8e5f5cb9..426011cbf7 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -59,13 +59,9 @@ struct TritonMetadata{CK,MD,DP} max_num_threads::Int end -function normalize_grid_and_blocks(grid_fn, metadata) - return normalize_grid_and_blocks(grid_fn(metadata), metadata) -end -function normalize_grid_and_blocks(grid::Integer, metadata) - return normalize_grid_and_blocks((grid,), metadata) -end -function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N} +normalize_grid(grid_fn, metadata) = normalize_grid(grid_fn(metadata), metadata) +normalize_grid(grid::Integer, metadata) = normalize_grid((grid,), metadata) +function normalize_grid(grid::Dims{N}, metadata) where {N} @assert N <= 3 @assert all(grid .> 0) return (grid..., ntuple(_ -> 1, 3 - N)...) 
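With this change the `grid` keyword accepted by the PythonCall Triton overlay can be an integer, a tuple of up to three extents, or a callable that receives the `TritonMetadata` built after compilation; all three forms are normalized to a 3-tuple by `normalize_grid` above. A minimal usage sketch, assuming a Triton kernel `add_kernel` imported via PythonCall, a vector length `n`, and a block size of 1024 (names and numbers are illustrative, not part of the patch), called from inside a function traced by Reactant as in the test files of this series:

```julia
# Each form below is normalized to the same 3-tuple launch grid.
add_kernel(x, y, out, n, 1024; grid=cld(n, 1024))        # plain integer
add_kernel(x, y, out, n, 1024; grid=(cld(n, 1024), 1))   # tuple, padded to 3D
add_kernel(x, y, out, n, 1024; grid=md -> cld(n, 1024))  # function of the TritonMetadata `md`
```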
@@ -81,7 +77,6 @@ function overlayed_pycall_with_triton( kernel::Py, args...; grid, - blocks, num_warps::Integer=4, num_stages::Integer=3, num_ctas::Integer=1, @@ -118,6 +113,7 @@ function overlayed_pycall_with_triton( ) # TODO: pass the device/client here from `compile` + # TODO: cluster dims client = Reactant.XLA.default_backend() @assert Reactant.XLA.platform_name(client) == "cuda" device = Reactant.XLA.default_device(client) @@ -167,8 +163,7 @@ function overlayed_pycall_with_triton( Int(n_max_threads[]), ) - grid = normalize_grid_and_blocks(grid, metadata) - blocks = normalize_grid_and_blocks(blocks, metadata) + grid = normalize_grid(grid, metadata) return @opcall triton_call( pyconvert(String, compiled_kernel.asm["source"]), @@ -177,10 +172,9 @@ function overlayed_pycall_with_triton( grid_x=@opcall(constant(grid[1])), grid_y=@opcall(constant(grid[2])), grid_z=@opcall(constant(grid[3])), - block_x=@opcall(constant(blocks[1])), - block_y=@opcall(constant(blocks[2])), - block_z=@opcall(constant(blocks[3])), - # The following are written to module attributes and restored later on + block_x=@opcall(constant(num_warps * device_properties.warp_size)), + block_y=@opcall(constant(1)), + block_z=@opcall(constant(1)), num_ctas, num_warps, ) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2e21837230..e3cde30f10 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1949,6 +1949,7 @@ function compile_mlir!( "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", ] end, ',', diff --git a/test/integration/triton/layer_norm.jl b/test/integration/triton/layer_norm.jl index 40adf866f6..1534b268fe 100644 --- a/test/integration/triton/layer_norm.jl +++ b/test/integration/triton/layer_norm.jl @@ -3,9 +3,12 @@ using PythonCall, Reactant, Test pyimport("sys").path.append(@__DIR__) layer_norm_kernel = pyimport("layer_norm").layer_norm_fwd_fused +layer_norm_kernel_v2 = pyimport("layer_norm").layer_norm_fwd_fused_simple + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") function layer_norm_triton( - x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T} + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T}, simple::Bool ) where {T} x_transposed = permutedims(x, (2, 1)) # match python array layout y = similar(x_transposed) @@ -20,9 +23,7 @@ function layer_norm_triton( throw(ArgumentError("This layer norm doesn't support feature dim >= 64KB.")) end - num_warps = min(max(block_size ÷ 256, 1), 8) - - layer_norm_kernel( + (simple ? 
layer_norm_kernel_v2 : layer_norm_kernel)( x_transposed, y, weight, @@ -33,10 +34,9 @@ function layer_norm_triton( N, 1.0f-5, block_size; - num_warps=num_warps, + num_warps=min(max(block_size ÷ 256, 1), 8), num_ctas=1, grid=(M,), - blocks=(block_size,), ) return permutedims(y, (2, 1)), mean, rstd @@ -57,11 +57,15 @@ end weight_ra = Reactant.to_rarray(rand(Float32, 256)) bias_ra = Reactant.to_rarray(rand(Float32, 256)) - y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra) + y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, false) y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra) + y_ra3, mean_ra3, rstd_ra3 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, true) - @test y_ra1 ≈ y_ra2 - @test mean_ra1 ≈ mean_ra2 - @test rstd_ra1 ≈ rstd_ra2 + @test_broken y_ra1 ≈ y_ra2 + @test_broken y_ra2 ≈ y_ra3 + @test_broken mean_ra1 ≈ mean_ra2 + @test mean_ra2 ≈ mean_ra3 + @test_broken rstd_ra1 ≈ rstd_ra2 + @test rstd_ra2 ≈ rstd_ra3 end -end \ No newline at end of file +end diff --git a/test/integration/triton/layer_norm.py b/test/integration/triton/layer_norm.py index c50e715dcf..9595491551 100644 --- a/test/integration/triton/layer_norm.py +++ b/test/integration/triton/layer_norm.py @@ -50,3 +50,54 @@ def layer_norm_fwd_fused( y = x_hat * w + b # Write output tl.store(Y + cols, y, mask=mask) + + +@triton.jit +def layer_norm_fwd_fused_simple( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + + # Compute mean - process one element at a time + mean = 0.0 + for i in range(N): + x = tl.load(X + i).to(tl.float32) + mean += x + mean = mean / N + + # Compute variance - process one element at a time + var = 0.0 + for i in range(N): + x = tl.load(X + i).to(tl.float32) + diff = x - mean + var += diff * diff + var = var / N + rstd = 1.0 / tl.sqrt(var + eps) + + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) diff --git a/test/integration/triton/libdevice.jl b/test/integration/triton/libdevice.jl index 6376ecd674..89eee78e99 100644 --- a/test/integration/triton/libdevice.jl +++ b/test/integration/triton/libdevice.jl @@ -8,7 +8,7 @@ const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") function asin_triton(x::AbstractVector{T}) where {T} out = similar(x) - asin_kernel(x, out, length(x), 1024; grid=(cld(length(x), 1024),), blocks=(1024,)) + asin_kernel(x, out, length(x), 1024; grid=(cld(length(x), 1024),)) return out end diff --git a/test/integration/triton/low_memory_dropout.jl b/test/integration/triton/low_memory_dropout.jl index a03f5fa030..48be41490b 100644 --- a/test/integration/triton/low_memory_dropout.jl +++ b/test/integration/triton/low_memory_dropout.jl @@ -10,15 +10,7 @@ function seeded_dropout(x::AbstractVector{T}, p::Number, seed) where {T} output = similar(x) mask = similar(x, Bool) low_memory_dropout_kernel( - x, - output, - mask, - length(x), - p, - seed, - 1024; - grid=(cld(length(x), 1024),), - blocks=(1024,), + x, output, mask, length(x), p, seed, 1024; grid=(cld(length(x), 1024),) ) return output, mask end diff --git a/test/integration/triton/matmul.jl b/test/integration/triton/matmul.jl new file mode 100644 index 0000000000..ea841dd771 --- /dev/null +++ b/test/integration/triton/matmul.jl @@ -0,0 +1,61 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +matmul_kernel = pyimport("matmul").matmul_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function matmul_triton(a::AbstractMatrix{T}, b::AbstractMatrix{T}) where {T} + # a: [M, K] --> aᵀ: [K, M] + # b: [K, N] --> bᵀ: [N, K] + # c: a × b [M, N] --> cᵀ: bᵀ × aᵀ [N, M] + a_transposed = permutedims(a, (2, 1)) # match python array layout + b_transposed = permutedims(b, (2, 1)) # match python array layout + @assert size(b_transposed, 2) == size(a_transposed, 1) "Inner dimensions must match \ + for matmul" + M, K = size(b_transposed) + K, N = size(a_transposed) + + out = similar(a_transposed, T, M, N) # cᵀ + + matmul_kernel( + b_transposed, + a_transposed, + out, + M, + N, + K, + Reactant.rowmajor_stride(b_transposed, 1), + Reactant.rowmajor_stride(b_transposed, 2), + Reactant.rowmajor_stride(a_transposed, 1), + Reactant.rowmajor_stride(a_transposed, 2), + Reactant.rowmajor_stride(out, 1), + Reactant.rowmajor_stride(out, 2), + 64, + 256, + 32, + 8; + grid=(cld(M, 64) * cld(N, 256),), + num_stages=4, + num_warps=4, + ) + + return permutedims(out, (2, 1)) +end + +@testset "matmul" begin + if RunningOnCUDA + @testset for M in (4, 32, 256, 1024), + K in (4, 32, 512, 2048), + N in (4, 32, 
256, 1024) + + a = Reactant.to_rarray(rand(Float32, M, K)) + b = Reactant.to_rarray(rand(Float32, K, N)) + + # XXX: shared_memory???? + # XXX: seems to work correctly for small matrices + @test_broken @jit(matmul_triton(a, b)) ≈ @jit(a * b) + end + end +end diff --git a/test/integration/triton/matmul.py b/test/integration/triton/matmul.py new file mode 100644 index 0000000000..f4dafc0318 --- /dev/null +++ b/test/integration/triton/matmul.py @@ -0,0 +1,264 @@ +import triton +import triton.language as tl + + +# XXX: enable and support autotuning +# @triton.autotune( +# configs=[ +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 32, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=5, +# num_warps=2, +# ), +# # Good config for fp8 inputs. +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 256, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 256, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# ], +# key=["M", "N", "K"], +# ) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. 
`stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Add some integer bound assumptions. + # This helps to guide integer analysis in the backend to optimize + # load/store offset address calculation + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. 
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) diff --git a/test/integration/triton/softmax.jl b/test/integration/triton/softmax.jl index 57ca96e7db..8b842a1008 100644 --- a/test/integration/triton/softmax.jl +++ b/test/integration/triton/softmax.jl @@ -45,7 +45,6 @@ function softmax_triton(x::AbstractMatrix{T}) where {T} BLOCK_SIZE, num_stages; grid=grid_fn, - blocks=(BLOCK_SIZE,), ) return permutedims(out, (2, 1)) diff --git a/test/integration/triton/vector_add.jl b/test/integration/triton/vector_add.jl index 548d192cc2..5a96e3b785 100644 --- a/test/integration/triton/vector_add.jl +++ b/test/integration/triton/vector_add.jl @@ -8,7 +8,7 @@ const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") function vector_add_triton(x::AbstractVector{T}, y::AbstractVector{T}) where {T} out = similar(x) - add_kernel(x, y, out, length(x), 1024; grid=(cld(length(x), 1024),), blocks=(1024,)) + add_kernel(x, y, out, length(x), 1024; grid=(cld(length(x), 1024),)) return out end diff --git a/test/runtests.jl b/test/runtests.jl index bda0a84c4f..41b73ebf86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,11 +62,11 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @testset "Triton" begin @safetestset "vector_add" include("integration/triton/vector_add.jl") @safetestset "softmax" include("integration/triton/softmax.jl") - # @safetestset "matmul" include("integration/triton/matmul.jl") + # @safetestset "matmul" include("integration/triton/matmul.jl") # XXX @safetestset "low_memory_dropout" include( "integration/triton/low_memory_dropout.jl" ) - @safetestset "layer norm" include("integration/triton/layer_norm.jl") + @safetestset "layer norm" include("integration/triton/layer_norm.jl") # XXX # @safetestset "attention" include("integration/triton/attention.jl") @safetestset "libdevice" include("integration/triton/libdevice.jl") # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl") From f1df17602138641e57cb7278563d1d06384309c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 14:10:53 -0500 Subject: [PATCH 27/28] fix: correct launch configuration --- deps/ReactantExtra/WORKSPACE | 3 +-- ext/ReactantPythonCallExt/pycall.jl | 11 ++++++---- src/CompileOptions.jl | 1 + src/Compiler.jl | 29 ++++++++++++++++++++++++--- src/Ops.jl | 12 +++++++++-- test/integration/triton/layer_norm.jl | 14 ++++++------- test/runtests.jl | 2 +- 7 files changed, 53 insertions(+), 19 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 159e549d49..24f589495d 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,8 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "8221b6147f497592205e6f558b1609e2964f3330" - +ENZYMEXLA_COMMIT = "4d71da26119a84662cd6f5252a68a35ca1673eae" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 426011cbf7..f328f6d4ac 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -59,9 +59,9 @@ struct TritonMetadata{CK,MD,DP} max_num_threads::Int end -normalize_grid(grid_fn, metadata) = normalize_grid(grid_fn(metadata), metadata) 
-normalize_grid(grid::Integer, metadata) = normalize_grid((grid,), metadata) -function normalize_grid(grid::Dims{N}, metadata) where {N} +canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata) +canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata) +function canonicalize_grid(grid::Dims{N}, metadata) where {N} @assert N <= 3 @assert all(grid .> 0) return (grid..., ntuple(_ -> 1, 3 - N)...) @@ -82,6 +82,7 @@ function overlayed_pycall_with_triton( num_ctas::Integer=1, hints=nothing, ) + @assert num_ctas == 1 "TODO: num_ctas > 1 not supported" triton = tritonptr[] mapped = map(signature_string, args) @@ -163,7 +164,7 @@ function overlayed_pycall_with_triton( Int(n_max_threads[]), ) - grid = normalize_grid(grid, metadata) + grid = canonicalize_grid(grid, metadata) return @opcall triton_call( pyconvert(String, compiled_kernel.asm["source"]), @@ -177,5 +178,7 @@ function overlayed_pycall_with_triton( block_z=@opcall(constant(1)), num_ctas, num_warps, + threads_per_warp=device_properties.warp_size, + enable_source_remat=false, ) end diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index f70e63f0be..e545aa6550 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -230,6 +230,7 @@ function CompileOptions(; :just_batch, :none, :no_triton, + :before_triton_lowering, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index e3cde30f10..6c27898574 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1327,9 +1327,7 @@ function triton_optimization_passes(device_properties) "cse", "symbol-dce", "triton-loop-unroll", - "preserve-triton-warps-ctas{save=true restore=false}", - "convert-triton-to-tritongpu{target=cuda:$(major_version)$(minor_version)}", - "preserve-triton-warps-ctas{save=false restore=true}", + "convert-triton-to-triton-gpu-preserving-module-attributes{target=cuda:$(major_version)$(minor_version)}", "tritongpu-coalesce", "tritongpu-F32DotTC", "triton-nvidia-gpu-plan-cta", @@ -1930,6 +1928,31 @@ function compile_mlir!( ), "no_triton", ) + elseif compile_options.optimization_passes === :before_triton_lowering + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + ["mark-func-memory-effects", opt_passes] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + opt_passes_with_triton, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + legalize_chlo_to_stablehlo..., + opt_passes2, + ] + end, + ',', + ), + "before_triton_lowering", + ) elseif compile_options.optimization_passes === :before_kernel run_pass_pipeline!( mod, diff --git a/src/Ops.jl b/src/Ops.jl index 7e778041c0..79971abf3e 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1852,14 +1852,22 @@ function triton_call( block_z::TracedRNumber{<:Integer}, num_ctas::Integer=1, num_warps::Integer=4, + threads_per_warp::Integer=32, + enable_source_remat::Bool=false, location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), ) _, symref, modop = _extract_function( mlir_code; func_name, func_op_kind="tt.func", location ) - MLIR.IR.attr!(modop, "ttg.num-wraps", MLIR.IR.Attribute(Int32(num_warps))) - MLIR.IR.attr!(modop, "ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas))) + MLIR.IR.attr!(modop, "enzymexla.ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps))) + MLIR.IR.attr!(modop, "enzymexla.ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas))) + MLIR.IR.attr!( + modop, "enzymexla.ttg.threads-per-warp", MLIR.IR.Attribute(Int32(threads_per_warp)) + ) + if enable_source_remat + 
MLIR.IR.attr!(modop, "enzymexla.ttg.enable-source-remat", MLIR.IR.UnitAttribute()) + end result_types = MLIR.IR.Type[] output_operand_aliases = MLIR.IR.Attribute[] diff --git a/test/integration/triton/layer_norm.jl b/test/integration/triton/layer_norm.jl index 1534b268fe..f9652da235 100644 --- a/test/integration/triton/layer_norm.jl +++ b/test/integration/triton/layer_norm.jl @@ -53,19 +53,19 @@ end @testset "fused_layer_norm" begin if RunningOnCUDA - x_ra = Reactant.to_rarray(rand(Float32, 256, 2056)) - weight_ra = Reactant.to_rarray(rand(Float32, 256)) - bias_ra = Reactant.to_rarray(rand(Float32, 256)) + x_ra = Reactant.to_rarray(rand(Float32, 257, 2056)) + weight_ra = Reactant.to_rarray(rand(Float32, 257)) + bias_ra = Reactant.to_rarray(rand(Float32, 257)) y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, false) y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra) y_ra3, mean_ra3, rstd_ra3 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, true) - @test_broken y_ra1 ≈ y_ra2 - @test_broken y_ra2 ≈ y_ra3 - @test_broken mean_ra1 ≈ mean_ra2 + @test y_ra1 ≈ y_ra2 + @test y_ra2 ≈ y_ra3 + @test mean_ra1 ≈ mean_ra2 @test mean_ra2 ≈ mean_ra3 - @test_broken rstd_ra1 ≈ rstd_ra2 + @test rstd_ra1 ≈ rstd_ra2 @test rstd_ra2 ≈ rstd_ra3 end end diff --git a/test/runtests.jl b/test/runtests.jl index 41b73ebf86..91389b3231 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,7 +66,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "low_memory_dropout" include( "integration/triton/low_memory_dropout.jl" ) - @safetestset "layer norm" include("integration/triton/layer_norm.jl") # XXX + @safetestset "layer norm" include("integration/triton/layer_norm.jl") # @safetestset "attention" include("integration/triton/attention.jl") @safetestset "libdevice" include("integration/triton/libdevice.jl") # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl") From 9dde89f1a72f3d9869f45e65cf737398ed5fe843 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Oct 2025 14:12:12 -0500 Subject: [PATCH 28/28] test: missing vars --- test/integration/triton/softmax.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/integration/triton/softmax.jl b/test/integration/triton/softmax.jl index 8b842a1008..815f390754 100644 --- a/test/integration/triton/softmax.jl +++ b/test/integration/triton/softmax.jl @@ -19,6 +19,8 @@ function softmax_triton(x::AbstractMatrix{T}) where {T} out = similar(x_transposed) n_rows, n_cols = size(x_transposed) + BLOCK_SIZE = nextpow(2, n_cols) + function grid_fn(metadata) occupancy = ( metadata.device_properties.regs_per_block ÷ @@ -43,7 +45,7 @@ function softmax_triton(x::AbstractMatrix{T}) where {T} n_rows, n_cols, BLOCK_SIZE, - num_stages; + num_stages=3; grid=grid_fn, )
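
As a closing illustration (not part of the patches), a rough sketch of how the launch configuration fits together after this series: the Triton BLOCK_SIZE is rounded up to a power of two so one program covers a full row, num_warps follows the heuristic already used in layer_norm.jl, and pycall.jl now derives the CUDA block size from num_warps and the device's warp size rather than from a user-supplied blocks argument. The concrete numbers assume the 257-column test case above and a warp size of 32 (typical for NVIDIA devices); they are illustrative, not prescriptive.

    n_cols = 257                                  # row length from the layer_norm test above
    BLOCK_SIZE = nextpow(2, n_cols)               # 512 — tl.arange expects a power-of-two extent
    num_warps = min(max(BLOCK_SIZE ÷ 256, 1), 8)  # 2  — same heuristic as layer_norm.jl
    warp_size = 32                                # stands in for device_properties.warp_size
    block_x = num_warps * warp_size               # 64 threads per CUDA block, as set in pycall.jl
    # grid_x is still supplied by the caller (e.g. grid=(M,) for one program per row).
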