-
Notifications
You must be signed in to change notification settings - Fork 45
Add tessera attribute
#1986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add tessera attribute
#1986
Changes from all commits
e8989a3
59c67e6
e389362
6804c97
d3c28d0
9f6498c
8507ee8
3d81375
056b03a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -26,14 +26,26 @@ import ..Reactant: | |||
| import Reactant: OptimizeCommunicationOptions, ShardyPropagationOptions, CompileOptions | ||||
| using Reactant_jll: Reactant_jll | ||||
|
|
||||
| import ..ReactantCore: correct_maybe_bcast_call | ||||
| import ..ReactantCore: correct_maybe_bcast_call, trace_function_definition | ||||
|
|
||||
| const DEBUG_PRINT_CODEGEN = Ref(false) | ||||
| const DEBUG_DISABLE_RESHARDING = Ref(false) | ||||
| const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false) | ||||
|
|
||||
| const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict() | ||||
|
|
||||
| # Registry for tessera_op annotations: maps function identity to tessera_op name | ||||
| const TESSERA_OP_REGISTRY = Dict{UInt,String}() | ||||
|
|
||||
| @inline function get_tessera_op(f::Function)::Union{Nothing,String} | ||||
| return get(TESSERA_OP_REGISTRY, objectid(f), nothing) | ||||
| end | ||||
|
|
||||
| @inline function set_tessera_op(f::Function, op_name::String) | ||||
| TESSERA_OP_REGISTRY[objectid(f)] = op_name | ||||
| return f | ||||
| end | ||||
|
|
||||
| @inline function traced_getfield(@nospecialize(obj::Dict), field) | ||||
| return Base.getindex(obj, field) | ||||
| end | ||||
|
|
@@ -2809,6 +2821,61 @@ macro jit(args...) | |||
| #! format: on | ||||
| end | ||||
|
|
||||
| """ | ||||
| @tessera_op(name) function foo(...) ... end | ||||
|
|
||||
| Marks a function with a tessera operation attribute. When this function is called | ||||
| during compilation, the tessera_op will be automatically passed through to the MLIR operation. | ||||
|
|
||||
| # Arguments | ||||
| - `name::String`: The name of the tessera operation (e.g., "inv", "matmul") | ||||
|
|
||||
| # Example | ||||
| ```julia | ||||
| @tessera_op("inv") function matrix_inverse(A::Matrix) | ||||
| return inv(A) | ||||
| end | ||||
|
|
||||
| # When called during compilation, tessera_op="inv" is passed automatically | ||||
| result = @compile matrix_inverse(some_matrix) | ||||
| ``` | ||||
| """ | ||||
| macro tessera_op(name, func_expr) | ||||
| # Validate that name is a string literal | ||||
| if !Meta.isexpr(name, :string) && !(name isa String) | ||||
| error("@tessera_op expects a string literal as the first argument, got: $(name)") | ||||
| end | ||||
|
|
||||
| # Extract the actual string value | ||||
| op_name = if Meta.isexpr(name, :string) | ||||
| name.args[1] | ||||
| elseif name isa String | ||||
| name | ||||
| else | ||||
| error("Could not extract string from @tessera_op argument") | ||||
| end | ||||
|
|
||||
| # Validate that func_expr is a function definition | ||||
| if !(func_expr.head in (:function, :(=))) | ||||
| error("@tessera_op(\"$op_name\") must be followed by a function definition") | ||||
| end | ||||
|
|
||||
| # Get function name from expression | ||||
| fname_expr = func_expr.args[1] | ||||
|
|
||||
| # Extract just the symbol if it's a call expression | ||||
| fname_sym = fname_expr isa Symbol ? fname_expr : fname_expr.args[1] | ||||
|
|
||||
| traced_expr = trace_function_definition(__module__, func_expr; tessera_op=op_name) | ||||
|
|
||||
| return quote | ||||
| # Define the function and register tessera_op | ||||
| $(esc(traced_expr)) | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
| Compiler.set_tessera_op($(esc(fname_sym)), $op_name) | ||||
| $(esc(fname_sym)) | ||||
| end | ||||
| end | ||||
|
|
||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
| function compile_call_expr(_mod, compiler, options::Dict, args...) | ||||
| while length(args) > 1 | ||||
| option, args = args[1], args[2:end] | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
|
|
||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -3017,14 +3017,20 @@ result = Ops.case( | |||
| return corrected_traced_results | ||||
| end | ||||
|
|
||||
| @noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__)) | ||||
| @noinline function call( | ||||
| f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_op=nothing | ||||
| ) | ||||
|
|
||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
| seen = Reactant.OrderedIdDict() | ||||
| cache_key = [] | ||||
| cache_key = Any[f, tessera_op] | ||||
| Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) | ||||
| cache = Reactant.Compiler.callcache() | ||||
| if haskey(cache, cache_key) | ||||
| # cache lookup: | ||||
| (; f_name, mlir_result_types, traced_result, mutated_args, linear_results, fnwrapped, argprefix, resprefix, resargprefix) = cache[cache_key] | ||||
| if !isnothing(tessera_op) | ||||
| MLIR.IR.setattr!(fnwrapped, "tessera_op", MLIR.IR.Attribute(tessera_op)) | ||||
| end | ||||
| else | ||||
| f_name = String(gensym(Symbol(f))) | ||||
|
|
||||
|
|
@@ -3059,6 +3065,9 @@ end | |||
| resprefix, | ||||
| resargprefix, | ||||
| ) | ||||
| if !isnothing(tessera_op) | ||||
| MLIR.IR.setattr!(temp.f, "tessera_op", MLIR.IR.Attribute(tessera_op)) | ||||
| end | ||||
| end | ||||
|
|
||||
| seen_cache = Reactant.OrderedIdDict() | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -268,7 +268,7 @@ include("Overlay.jl") | |||||||
| # Serialization | ||||||||
| include("serialization/Serialization.jl") | ||||||||
|
|
||||||||
| using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile | ||||||||
| using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, @tessera_op, traced_getfield, compile | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
| export ConcreteRArray, | ||||||||
| ConcreteRNumber, | ||||||||
| ConcretePJRTArray, | ||||||||
|
|
@@ -281,6 +281,7 @@ export ConcreteRArray, | |||||||
| @code_xla, | ||||||||
| @jit, | ||||||||
| @trace, | ||||||||
| @tessera_op, | ||||||||
| within_compile | ||||||||
|
|
||||||||
| const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,21 @@ | ||||||||
| using Reactant, Test | ||||||||
|
|
||||||||
| @tessera_op "reciprocal" function reciprocal(x) | ||||||||
| return 1 ./ x | ||||||||
| end | ||||||||
|
|
||||||||
|
|
||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
| @trace tessera function foo(x) | ||||||||
| return sin.(sum(x) .+ x) | ||||||||
| end | ||||||||
|
|
||||||||
|
|
||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
| @testset "Tessera Annotation Tests" begin | ||||||||
| x = Reactant.to_rarray(rand(3)) | ||||||||
| # if optimize=false is not set, the function is inlined. | ||||||||
| hlo = repr(@code_hlo optimize = false reciprocal(x)) | ||||||||
| @test occursin("tessera_op = \"reciprocal\"", hlo) | ||||||||
|
|
||||||||
| hlo2 = repr(@code_hlo optimize = false foo(x)) | ||||||||
| @test occursin("tessera_op = \"foo\"", hlo2) | ||||||||
| end | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶