28 commits
12d0d31
feat: julia api to access device properties [skip ci]
avik-pal Oct 16, 2025
77da391
fix: apply suggestion from @avik-pal [skp ci]
avik-pal Oct 16, 2025
c80c3de
chore: bump reactant_jll
avik-pal Oct 17, 2025
fd9f3c9
fix: remove deleted fields [skip ci]
avik-pal Oct 18, 2025
33e9d68
feat: initial triton setup [skip ci]
avik-pal Sep 25, 2025
75f000e
feat: auto-trace triton code
avik-pal Sep 25, 2025
66d3580
feat: copy tt.func into main module [skip ci]
avik-pal Sep 25, 2025
9447d0b
feat: tracing fully functional
avik-pal Sep 27, 2025
0500397
fix: hlo_call
avik-pal Sep 27, 2025
532a302
feat: more triton passes + keep triton func in a separate module
avik-pal Sep 28, 2025
e1d9fc0
feat: put the tt func in a separate module and use symbol ref
avik-pal Sep 28, 2025
fa02d6e
feat: new triton_ext dialect
avik-pal Sep 29, 2025
a3b8cb6
feat: triton tracing works now finally
avik-pal Sep 29, 2025
97c952b
fix: kind of working
avik-pal Sep 29, 2025
b84d519
fix: new API
avik-pal Sep 30, 2025
732e6c3
feat: return values
avik-pal Sep 30, 2025
427b54a
feat: lowering triton now works
avik-pal Oct 4, 2025
a07bb13
feat: triton working end to end
avik-pal Oct 4, 2025
ed521e2
chore: bump commit
avik-pal Oct 16, 2025
bea3b16
fix: extra export + naming
avik-pal Oct 16, 2025
4a94a22
feat: allow grid/blocks via a function [skip ci]
avik-pal Oct 16, 2025
b017d5b
feat: use new device properties [skip ci]
avik-pal Oct 16, 2025
6af5f94
feat: correctly set strides + get n_regs
avik-pal Oct 17, 2025
3ef41ec
test: add some triton tests
avik-pal Oct 17, 2025
1b84dd9
test: layer_norm + libdevice
avik-pal Oct 17, 2025
e9623d0
fix: partial fix to the blocks
avik-pal Oct 17, 2025
f1df176
fix: correct launch configuration
avik-pal Oct 17, 2025
9dde89f
test: missing vars
avik-pal Oct 17, 2025
1 change: 1 addition & 0 deletions CondaPkg.toml
@@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4"
jax = ">= 0.6"
tensorflow = ">= 2.17"
numpy = ">= 2"
triton = ">= 3.4"
24 changes: 24 additions & 0 deletions deps/ReactantExtra/API.cpp
@@ -812,6 +812,26 @@ REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
jlprops->maxThreadsPerMultiProcessor = props.maxThreadsPerMultiProcessor;
}

REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
const char *binary, const char *fnname, int32_t *regs, int32_t *spills,
int32_t *maxThreads) {
CUfunction fun;
CUmodule mod;

ReactantHandleCuResult(cuModuleLoadData(&mod, binary));
ReactantHandleCuResult(cuModuleGetFunction(&fun, mod, fnname));

ReactantHandleCuResult(
cuFuncGetAttribute(regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
ReactantHandleCuResult(
cuFuncGetAttribute(spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
*spills /= 4;
ReactantHandleCuResult(cuFuncGetAttribute(
maxThreads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));

return;
}

#else

REACTANT_ABI int32_t ReactantCudaDriverGetVersion() { return 0; }
@@ -827,6 +847,10 @@ REACTANT_ABI int32_t ReactantCudaDeviceGetWarpSizeInThreads() { return 0; }
REACTANT_ABI void ReactantCudaDeviceGetProperties(DeviceProperties *jlprops,
int32_t device_id) {}

REACTANT_ABI void ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
const char *binary, const char *fnname, int32_t *regs, int32_t *spills,
int32_t *maxThreads) {}

#endif

REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
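The new entry point loads a compiled CUDA binary with the driver API and reads `CU_FUNC_ATTRIBUTE_NUM_REGS`, `CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES` (divided by 4, presumably to express spilled bytes in 4-byte register units), and `CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK`. A minimal Julia wrapper around it, assuming the symbol is exported from the Reactant runtime library (the BUILD change below does that), might look like the sketch here; `cuda_binary_stats` is a hypothetical helper, and the PR's actual call site appears further down in `ext/ReactantPythonCallExt/pycall.jl`:

```julia
# Hypothetical helper, not part of the PR: query register/spill/thread limits
# for a kernel inside a cubin via the new C ABI entry point.
function cuda_binary_stats(cubin::Vector{UInt8}, fname::String)
    # Initialize to zero: on non-CUDA builds the stub leaves the outputs untouched.
    regs, spills, maxthreads = Ref{Int32}(0), Ref{Int32}(0), Ref{Int32}(0)
    GC.@preserve cubin fname regs spills maxthreads begin
        @ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
            cubin::Ptr{Cvoid}, fname::Cstring,
            regs::Ptr{Int32}, spills::Ptr{Int32}, maxthreads::Ptr{Int32},
        )::Cvoid
    end
    return (; regs=regs[], spills=spills[], maxthreads=maxthreads[])
end
```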
19 changes: 19 additions & 0 deletions deps/ReactantExtra/BUILD
@@ -984,6 +984,7 @@ cc_library(
"-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads",
"-Wl,-exported_symbol,_ReactantCudaDeviceGetProperties",
"-Wl,-exported_symbol,_ReactantCudaGetRegsSpillsMaxThreadsFromBinary",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",
@@ -1436,6 +1437,24 @@
],
)

gentbl_cc_library(
name = "TritonExtJLIncGen",
tbl_outs = [
(
[
"--generator=jl-op-defs",
"--disable-module-wrap=0",
],
"TritonExt.jl",
),
],
tblgen = "//:mlir-jl-tblgen",
td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td",
deps = [
"@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles",
],
)

gentbl_cc_library(
name = "TPUJLIncGen",
tbl_outs = [
3 changes: 1 addition & 2 deletions deps/ReactantExtra/WORKSPACE
@@ -4,8 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "6137ac98e710adf6f4e953bf441db4e25b2db40f"

ENZYMEXLA_COMMIT = "4d71da26119a84662cd6f5252a68a35ca1673eae"
ENZYMEXLA_SHA256 = ""

http_archive(
1 change: 1 addition & 0 deletions deps/ReactantExtra/make-bindings.jl
@@ -42,6 +42,7 @@ for file in [
"MPI.jl",
"MemRef.jl",
"SparseTensor.jl",
"TritonExt.jl",
]
build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -131,6 +131,7 @@ export default defineConfig({
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
{ text: "Triton", link: "/api/dialects/triton" },
{ text: "TritonExt", link: "/api/dialects/tritonext" },
{ text: "TPU", link: "/api/dialects/tpu" },
{ text: "VHLO", link: "/api/dialects/vhlo" },
],
@@ -221,6 +222,7 @@
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
{ text: "StableHLO", link: "/api/dialects/stablehlo" },
{ text: "Triton", link: "/api/dialects/triton" },
{ text: "TritonExt", link: "/api/dialects/tritonext" },
{ text: "TPU", link: "/api/dialects/tpu" },
{ text: "VHLO", link: "/api/dialects/vhlo" },
],
11 changes: 11 additions & 0 deletions docs/src/api/dialects/tritonext.md
@@ -0,0 +1,11 @@
```@meta
CollapsedDocStrings = true
```

# TritonExt Dialect

Provides extensions to the Triton dialect.

```@autodocs
Modules = [Reactant.MLIR.Dialects.triton_ext]
```
38 changes: 37 additions & 1 deletion ext/ReactantPythonCallExt/ReactantPythonCallExt.jl
@@ -1,14 +1,20 @@
module ReactantPythonCallExt

using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
using PythonCall:
PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
using Reactant.Ops: @opcall
using Reactant_jll: Reactant_jll

const jaxptr = Ref{Py}()
const jnpptr = Ref{Py}()

const JAX_TRACING_SUPPORTED = Ref{Bool}(false)

const tritonptr = Ref{Py}()

const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false)

const tfptr = Ref{Py}()
const tf2xlaptr = Ref{Py}()
const npptr = Ref{Py}()
@@ -33,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict(
ComplexF64 => :complex64,
)

const MLIR_TYPE_STRING = Dict(
Float64 => "fp64",
Float32 => "fp32",
Float16 => "fp16",
Int64 => "i64",
Int32 => "i32",
Int16 => "i16",
Int8 => "i8",
UInt64 => "ui64",
UInt32 => "ui32",
UInt16 => "ui16",
UInt8 => "ui8",
Bool => "i1",
Reactant.F8E4M3FN => "fp8e4nv",
Reactant.F8E5M2FNUZ => "fp8e5b16",
Reactant.F8E4M3FNUZ => "fp8e4b8",
Reactant.F8E5M2 => "fp8e5",
)
if isdefined(Core, :BFloat16)
MLIR_TYPE_STRING[Core.BFloat16] = "bf16"
end

function __init__()
try
jaxptr[] = pyimport("jax")
@@ -43,6 +71,14 @@ function __init__()
be supported." exception = (err, catch_backtrace())
end

try
tritonptr[] = pyimport("triton")
TRITON_COMPILE_SUPPORTED[] = true
catch err
@warn "Failed to import triton. Compiling jax functions with triton won't be \
supported." exception = (err, catch_backtrace())
end

try
tfptr[] = pyimport("tensorflow")
tfptr[].config.set_visible_devices(pylist(); device_type="GPU")
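`MLIR_TYPE_STRING` translates Julia element types into the short type strings Triton expects in a kernel signature; `signature_string` in `pycall.jl` (further down) prefixes array arguments with `*` to mark them as pointers. Purely illustrative lookups, using only the table added above:

```julia
# Illustrative only: values come straight from MLIR_TYPE_STRING.
MLIR_TYPE_STRING[Float32]          # "fp32"  -- scalar argument
"*" * MLIR_TYPE_STRING[Float16]    # "*fp16" -- pointer argument, as built for a TracedRArray{Float16}
```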
6 changes: 3 additions & 3 deletions ext/ReactantPythonCallExt/overlays.jl
@@ -1,7 +1,7 @@
@reactant_overlay function PythonCall.pycall(f::Py, args...)
@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...)
if Reactant.looped_any(Reactant.use_overlayed_version, args)
return pycall_with_jax_tracing(f, args...)
return overlayed_pycall(f, args...; kwargs...)
else
return Base.inferencebarrier(PythonCall.pycall)(f, args...)
return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...)
end
end
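With keyword arguments now forwarded, the overlay can carry Triton launch parameters (`grid`, `num_warps`, ...) through `pycall`. A rough sketch of the resulting dispatch, under the assumption that `f` is a Python callable (the real code uses `Reactant.looped_any` rather than `any`):

```julia
# Sketch of the dispatch, not the literal implementation.
function dispatch_example(f::Py, args...; kwargs...)
    if any(Reactant.use_overlayed_version, args)
        # traced arguments: route to the Triton or JAX tracing path
        overlayed_pycall(f, args...; kwargs...)
    else
        # plain values: behave exactly like an ordinary pycall
        PythonCall.pycall(f, args...; kwargs...)
    end
end
```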
149 changes: 148 additions & 1 deletion ext/ReactantPythonCallExt/pycall.jl
@@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
)
end

function pycall_with_jax_tracing(f::Py, args...)
function overlayed_pycall(f::Py, args...; kwargs...)
@assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
    # TODO: check for Autotuner and Heuristics as well
if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
return overlayed_pycall_with_triton(f, args...; kwargs...)
else
@assert isempty(kwargs) "`kwargs` are not supported for jax traced functions."
return overlayed_pycall_with_jax_tracing(f, args...)
end
end

function overlayed_pycall_with_jax_tracing(f::Py, args...)
JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.")

seen_args = Reactant.OrderedIdDict()
@@ -35,3 +46,139 @@ function pycall_with_jax_tracing(f::Py, args...)
res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...)
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
end

struct TritonMetadata{CK,MD,DP}
compiled_kernel::CK
metadata::MD
device_properties::DP
num_warps::Int
num_stages::Int
num_ctas::Int
num_regs::Int
num_spills::Int
max_num_threads::Int
end

canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata)
canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata)
function canonicalize_grid(grid::Dims{N}, metadata) where {N}
@assert N <= 3
@assert all(grid .> 0)
return (grid..., ntuple(_ -> 1, 3 - N)...)
end

signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
signature_string(x::T) where {T<:Number} = string(x), x
signature_string(x) = error("Unsupported argument type: $(typeof(x))")

# TODO: better name for hints?
function overlayed_pycall_with_triton(
kernel::Py,
args...;
grid,
num_warps::Integer=4,
num_stages::Integer=3,
num_ctas::Integer=1,
hints=nothing,
)
@assert num_ctas == 1 "TODO: num_ctas > 1 not supported"
triton = tritonptr[]

mapped = map(signature_string, args)
signature = first.(mapped)
# TODO: are hints actually correctly set?
hints =
hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
constants = Dict(
kernel.arg_names[i - 1] => constant for
(i, constant) in enumerate(last.(mapped)) if constant !== nothing
)
for (k, v) in hints
v == 1 && (constants[kernel.arg_names[k - 1]] = v)
end
attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)

sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
for k in keys(constants)
sigmap[k] = "constexpr"
end

for h in values(hints)
@assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
end
attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)

src = triton.compiler.ASTSource(;
fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
)

# TODO: pass the device/client here from `compile`
# TODO: cluster dims
client = Reactant.XLA.default_backend()
@assert Reactant.XLA.platform_name(client) == "cuda"
device = Reactant.XLA.default_device(client)
device_properties = Reactant.XLA.device_properties(device)

target = triton.backends.compiler.GPUTarget(
Reactant.XLA.platform_name(client),
parse(Int, "$(device_properties.major)$(device_properties.minor)"),
device_properties.warp_size,
)
backend = triton.compiler.make_backend(target)
options = backend.parse_options(
pydict(
"num_warps" => num_warps,
"num_stages" => num_stages,
"num_ctas" => num_ctas,
"extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
),
)

# Currently we are doing a double compilation here. can we do better?
# we are compiling here + lowering again inside enzymejax
compiled_kernel = triton.compile(src; target=target, options=options.__dict__)

cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"])
fname = pyconvert(String, compiled_kernel.metadata.name)
n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}()
GC.@preserve cubin fname n_regs n_spills n_max_threads begin
@ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary(
cubin::Ptr{Cvoid},
fname::Cstring,
n_regs::Ptr{Int32},
n_spills::Ptr{Int32},
n_max_threads::Ptr{Int32},
)::Cvoid
end

metadata = TritonMetadata(
compiled_kernel,
compiled_kernel.metadata,
device_properties,
num_warps,
num_stages,
num_ctas,
Int(n_regs[]),
Int(n_spills[]),
Int(n_max_threads[]),
)

grid = canonicalize_grid(grid, metadata)

return @opcall triton_call(
pyconvert(String, compiled_kernel.asm["source"]),
filter(x -> x isa Reactant.TracedType, args)...;
func_name=fname,
grid_x=@opcall(constant(grid[1])),
grid_y=@opcall(constant(grid[2])),
grid_z=@opcall(constant(grid[3])),
block_x=@opcall(constant(num_warps * device_properties.warp_size)),
block_y=@opcall(constant(1)),
block_z=@opcall(constant(1)),
num_ctas,
num_warps,
threads_per_warp=device_properties.warp_size,
enable_source_remat=false,
)
end
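Putting the pieces together, a hedged end-to-end sketch of how this path is meant to be used. The Python module `my_kernels`, its `add_kernel`, and the wrapper `vadd!` are all hypothetical; only the keyword interface (`grid`, `num_warps`, `num_stages`, `num_ctas`, `hints`) comes from `overlayed_pycall_with_triton` above:

```julia
using PythonCall, Reactant

# Hypothetical: a Python module exposing a @triton.jit kernel that writes
# x + y into out, BLOCK elements per program id.
add_kernel = pyimport("my_kernels").add_kernel

function vadd!(out, x, y)
    n = length(x)   # plain numbers are passed to Triton as constexprs
    BLOCK = 256
    # Calling the JITFunction with traced arrays hits the pycall overlay, which
    # compiles the kernel with Triton, queries its register/spill usage, and
    # emits the triton_call op shown above. `grid` may be an integer, a tuple,
    # or a function of the launch metadata.
    add_kernel(x, y, out, n, BLOCK; grid=cld(n, BLOCK), num_warps=4)
    return out
end

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(rand(Float32, 1024))
out = Reactant.to_rarray(zeros(Float32, 1024))
vadd_compiled = Reactant.@compile vadd!(out, x, y)
```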
2 changes: 2 additions & 0 deletions src/CompileOptions.jl
@@ -229,6 +229,8 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:no_triton,
:before_triton_lowering,
]
end

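The two new symbols extend the accepted values of the `optimize` compile option; presumably `:no_triton` skips the Triton-related passes and `:before_triton_lowering` stops the pipeline just before Triton kernels are lowered, which is useful for inspecting the IR. A hedged usage sketch, reusing the hypothetical `vadd!` from the example above (the exact pipeline behind each symbol is defined elsewhere in Reactant):

```julia
# Inspect the generated module before Triton kernels are lowered (sketch only).
hlo = Reactant.@code_hlo optimize = :before_triton_lowering vadd!(out, x, y)
```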