Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ macro trace(args...)
track_numbers = true
checkpointing = false
mincut = false
tessera = false

expr = first(args)
while length(args) > 1
Expand All @@ -165,6 +166,9 @@ macro trace(args...)
mincut = val
end
args = args[2:end]
elseif args[1] === :tessera
tessera = true
args = args[2:end]
else
break
end
Expand All @@ -191,9 +195,10 @@ macro trace(args...)
end
)
)
return esc(trace_function_definition(__module__, expr))
return esc(trace_function_definition(__module__, expr; tessera))
end
#! format: on
@assert !tessera "tessera annotation is only allowed in front of function definitions"

if Meta.isexpr(expr, :(=))
if Meta.isexpr(expr.args[2], :if)
Expand Down Expand Up @@ -241,10 +246,18 @@ function get_argname(expr)
return expr, expr
end

function trace_function_definition(mod, expr)
function trace_function_definition(mod, expr; tessera=false, tessera_op=nothing)
internal_fn = MacroTools.splitdef(expr)
orig_fname = internal_fn[:name]

tessera_op = if !isnothing(tessera_op)
tessera_op
elseif tessera
String(orig_fname)
else
nothing
end

isfunctor = Meta.isexpr(orig_fname, :(::))
fname = gensym(Symbol(orig_fname, :internal))
internal_fn[:name] = fname
Expand All @@ -269,12 +282,18 @@ function trace_function_definition(mod, expr)
end

if isempty(new_fn[:kwargs])
traced_call_expr = :($(traced_call)($(fname), $(argnames...)))
traced_call_expr =
:($(traced_call)($(fname), $(argnames...); tessera_op=$(tessera_op)))
untraced_call_expr = :($(fname)($(argnames...)))
else
kws = first.(get_argname.(new_fn[:kwargs]))
traced_call_expr =
:($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...)))
traced_call_expr = :($(traced_call)(
Core.kwcall,
(; $(kws...)),
$(fname),
$(argnames...);
tessera_op=$(tessera_op),
))
untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...)))
end

Expand Down Expand Up @@ -693,7 +712,7 @@ end

function traced_while end # defined inside Reactant.jl

traced_call(f, args...; kwargs...) = f(args...; kwargs...)
traced_call(f, args...; tessera_op=nothing, kwargs...) = f(args...; kwargs...)

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
Expand Down
69 changes: 68 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,26 @@ import ..Reactant:
import Reactant: OptimizeCommunicationOptions, ShardyPropagationOptions, CompileOptions
using Reactant_jll: Reactant_jll

import ..ReactantCore: correct_maybe_bcast_call
import ..ReactantCore: correct_maybe_bcast_call, trace_function_definition

const DEBUG_PRINT_CODEGEN = Ref(false)
const DEBUG_DISABLE_RESHARDING = Ref(false)
const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false)

const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict()

# Registry for tessera_op annotations: maps function identity to tessera_op name
const TESSERA_OP_REGISTRY = Dict{UInt,String}()

@inline function get_tessera_op(f::Function)::Union{Nothing,String}
return get(TESSERA_OP_REGISTRY, objectid(f), nothing)
end

@inline function set_tessera_op(f::Function, op_name::String)
TESSERA_OP_REGISTRY[objectid(f)] = op_name
return f
end

@inline function traced_getfield(@nospecialize(obj::Dict), field)
return Base.getindex(obj, field)
end
Expand Down Expand Up @@ -2809,6 +2821,61 @@ macro jit(args...)
#! format: on
end

"""
@tessera_op(name) function foo(...) ... end

Marks a function with a tessera operation attribute. When this function is called
during compilation, the tessera_op will be automatically passed through to the MLIR operation.

# Arguments
- `name::String`: The name of the tessera operation (e.g., "inv", "matmul")

# Example
```julia
@tessera_op("inv") function matrix_inverse(A::Matrix)
return inv(A)
end

# When called during compilation, tessera_op="inv" is passed automatically
result = @compile matrix_inverse(some_matrix)
```
"""
macro tessera_op(name, func_expr)
# Validate that name is a string literal
if !Meta.isexpr(name, :string) && !(name isa String)
error("@tessera_op expects a string literal as the first argument, got: $(name)")
end

# Extract the actual string value
op_name = if Meta.isexpr(name, :string)
name.args[1]
elseif name isa String
name
else
error("Could not extract string from @tessera_op argument")
end

# Validate that func_expr is a function definition
if !(func_expr.head in (:function, :(=)))
error("@tessera_op(\"$op_name\") must be followed by a function definition")
end

# Get function name from expression
fname_expr = func_expr.args[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
fname_expr = func_expr.args[1]


# Extract just the symbol if it's a call expression
fname_sym = fname_expr isa Symbol ? fname_expr : fname_expr.args[1]

traced_expr = trace_function_definition(__module__, func_expr; tessera_op=op_name)

return quote
# Define the function and register tessera_op
$(esc(traced_expr))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$(esc(traced_expr))

Compiler.set_tessera_op($(esc(fname_sym)), $op_name)
$(esc(fname_sym))
end
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

function compile_call_expr(_mod, compiler, options::Dict, args...)
while length(args) > 1
option, args = args[1], args[2:end]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
option, args = args[1], args[2:end]

Expand Down
7 changes: 5 additions & 2 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ function ReactantCore.traced_if(
return @opcall if_condition(cond, true_fn, false_fn, args...; track_numbers)
end

function ReactantCore.traced_call(f::Function, args...)
return @opcall call(f, args...)
function ReactantCore.traced_call(f::Function, args...; tessera_op=nothing)
if isnothing(tessera_op)
tessera_op = Reactant.Compiler.get_tessera_op(f)
end
return @opcall call(f, args...; tessera_op)
end

function ReactantCore.traced_while(
Expand Down
13 changes: 11 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3017,14 +3017,20 @@ result = Ops.case(
return corrected_traced_results
end

@noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__))
@noinline function call(
f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_op=nothing
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

seen = Reactant.OrderedIdDict()
cache_key = []
cache_key = Any[f, tessera_op]
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
cache = Reactant.Compiler.callcache()
if haskey(cache, cache_key)
# cache lookup:
(; f_name, mlir_result_types, traced_result, mutated_args, linear_results, fnwrapped, argprefix, resprefix, resargprefix) = cache[cache_key]
if !isnothing(tessera_op)
MLIR.IR.setattr!(fnwrapped, "tessera_op", MLIR.IR.Attribute(tessera_op))
end
else
f_name = String(gensym(Symbol(f)))

Expand Down Expand Up @@ -3059,6 +3065,9 @@ end
resprefix,
resargprefix,
)
if !isnothing(tessera_op)
MLIR.IR.setattr!(temp.f, "tessera_op", MLIR.IR.Attribute(tessera_op))
end
end

seen_cache = Reactant.OrderedIdDict()
Expand Down
3 changes: 2 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ include("Overlay.jl")
# Serialization
include("serialization/Serialization.jl")

using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, @tessera_op, traced_getfield, compile
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, @tessera_op, traced_getfield, compile
using .Compiler:
@compile, @code_hlo, @code_mhlo, @jit, @code_xla, @tessera_op, traced_getfield, compile

export ConcreteRArray,
ConcreteRNumber,
ConcretePJRTArray,
Expand All @@ -281,6 +281,7 @@ export ConcreteRArray,
@code_xla,
@jit,
@trace,
@tessera_op,
within_compile

const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()
Expand Down
21 changes: 21 additions & 0 deletions test/tessera.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Reactant, Test

@tessera_op "reciprocal" function reciprocal(x)
return 1 ./ x
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@trace tessera function foo(x)
return sin.(sum(x) .+ x)
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@testset "Tessera Annotation Tests" begin
x = Reactant.to_rarray(rand(3))
# if optimize=false is not set, the function is inlined.
hlo = repr(@code_hlo optimize = false reciprocal(x))
@test occursin("tessera_op = \"reciprocal\"", hlo)

hlo2 = repr(@code_hlo optimize = false foo(x))
@test occursin("tessera_op = \"foo\"", hlo2)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Loading