Skip to content

Commit 816e789

Browse files
Pangorawmofeing
andauthored
Add Ops.hlo_call(::String, args...) (EnzymeAD#358)
* special case String and Module in make_tracer * implement Ops.hlo_call * formatting * Update src/Ops.jl Co-authored-by: Sergio Sánchez Ramírez <[email protected]> * SymbolTable: fix lookup * cache and more validation, also specify name to call * error if not func.func * only do special things for func.func * symbol_rename * add multiple call test * rename then remove from parsed module --------- Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
1 parent e40d715 commit 816e789

File tree

4 files changed

+249
-3
lines changed

4 files changed

+249
-3
lines changed

src/Ops.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,4 +1046,127 @@ function compare(
10461046
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
10471047
end
10481048

1049+
# Generate a unique name given a module hash and a function name.
1050+
function _hlo_call_name(orig_name, module_suffix)
1051+
return orig_name * "_hlo_call_" * module_suffix
10491052
end
1053+
1054+
"""
1055+
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1056+
1057+
Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1058+
with the provided arguments and return a tuple for each result of the call.
1059+
1060+
```julia-repl
1061+
julia> Reactant.@jit(
1062+
Ops.hlo_call(
1063+
\"\"\"
1064+
module {
1065+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1066+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1067+
return %0 : tensor<3xf32>
1068+
}
1069+
}
1070+
\"\"\",
1071+
Reactant.to_rarray(Float32[1, 2, 3]),
1072+
Reactant.to_rarray(Float32[1, 2, 3]),
1073+
)
1074+
)
1075+
(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1076+
```
1077+
"""
1078+
function hlo_call(
1079+
code,
1080+
args...;
1081+
func_name="main",
1082+
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
1083+
)
1084+
module_suffix = string(hash(code); base=16)
1085+
name_to_call = _hlo_call_name(func_name, module_suffix)
1086+
1087+
current_module = MLIR.IR.mmodule()
1088+
top_level_block = MLIR.IR.body(current_module)
1089+
1090+
symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
1091+
1092+
fn = MLIR.IR.lookup(
1093+
MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
1094+
)
1095+
if isnothing(fn)
1096+
new_mod = parse(MLIR.IR.Module, code)
1097+
new_mod_op = MLIR.IR.Operation(new_mod)
1098+
body = MLIR.IR.body(new_mod)
1099+
1100+
operations = collect(MLIR.IR.OperationIterator(body))
1101+
for op in operations
1102+
if MLIR.IR.name(op) == "func.func"
1103+
fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
1104+
if fn_name == func_name
1105+
fn = op
1106+
end
1107+
1108+
new_name = _hlo_call_name(fn_name, module_suffix)
1109+
res = MLIR.IR.LogicalResult(
1110+
MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
1111+
fn_name, new_name, new_mod_op
1112+
),
1113+
)
1114+
@assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
1115+
1116+
# Set function private
1117+
MLIR.IR.attr!(
1118+
op,
1119+
MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
1120+
MLIR.IR.Attribute("private"),
1121+
)
1122+
1123+
# Change function name
1124+
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
1125+
end
1126+
end
1127+
1128+
for op in operations
1129+
MLIR.IR.rmfromparent!(op)
1130+
push!(top_level_block, op)
1131+
end
1132+
end
1133+
1134+
if isnothing(fn)
1135+
error("hlo_call: could not find function $func_name in the provided module")
1136+
end
1137+
1138+
ftype_attr = MLIR.IR.attr(fn, "function_type")
1139+
ftype = MLIR.IR.Type(ftype_attr)
1140+
1141+
@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "hlo_call: all inputs to hlo_call should be reactant arrays"
1142+
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"
1143+
1144+
for (i, arg) in enumerate(args)
1145+
expected_type = MLIR.IR.input(ftype, i)
1146+
arg_type = MLIR.IR.type(arg.mlir_data)
1147+
@assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)"
1148+
end
1149+
1150+
operands = [a.mlir_data for a in args]
1151+
call = MLIR.Dialects.func.call(
1152+
operands;
1153+
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
1154+
callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
1155+
location,
1156+
)
1157+
1158+
return ntuple(MLIR.IR.nresults(call)) do i
1159+
out = MLIR.IR.result(call, i)
1160+
ty = MLIR.IR.type(out)
1161+
sz = MLIR.IR.size(ty)
1162+
T = MLIR.IR.julia_type(eltype(ty))
1163+
N = length(sz)
1164+
if N == 0
1165+
Reactant.TracedRNumber{T}((), out)
1166+
else
1167+
Reactant.TracedRArray{T,N}((), out, sz)
1168+
end
1169+
end
1170+
end
1171+
1172+
end # module Ops

src/Tracing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ function make_tracer(
284284
@assert Base.isconcretetype(RT)
285285
nf = fieldcount(RT)
286286

287+
if TT === Module || TT === String
288+
return prev
289+
end
290+
287291
if ismutabletype(TT)
288292
y = ccall(:jl_new_struct_uninit, Any, (Any,), TT)
289293
seen[prev] = y

src/mlir/IR/SymbolTable.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,17 @@ Base.convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st.st
2525
Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol.
2626
If the symbol cannot be found, returns a null operation.
2727
"""
28-
lookup(st::SymbolTable, name::AbstractString) =
29-
Operation(API.mlirSymbolTableLookup(st, name))
30-
Base.getindex(st::SymbolTable, name::AbstractString) = lookup(st, name)
28+
function lookup(st::SymbolTable, name::AbstractString)
29+
raw_op = API.mlirSymbolTableLookup(st, name)
30+
if raw_op.ptr == C_NULL
31+
nothing
32+
else
33+
Operation(raw_op, false)
34+
end
35+
end
36+
function Base.getindex(st::SymbolTable, name::AbstractString)
37+
@something(lookup(st, name), throw(KeyError(name)))
38+
end
3139

3240
"""
3341
push!(symboltable, operation)

test/ops.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,3 +866,114 @@ end
866866
z = ConcreteRArray([1e-8, 0.001, 2.0])
867867
@test SpecialFunctions.zeta.(Array(s), Array(z)) @jit Ops.zeta(s, z)
868868
end
869+
870+
@testset "hlo_call" begin
871+
x = Float32[1.0, 2.0, 50.0]
872+
y = Float32[-4.0, 0.001, 2.0]
873+
x_reactant = Reactant.to_rarray(x)
874+
y_reactant = Reactant.to_rarray(y)
875+
876+
@test Reactant.@jit(
877+
Ops.hlo_call(
878+
"""
879+
module {
880+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
881+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
882+
return %0 : tensor<3xf32>
883+
}
884+
}
885+
""",
886+
x_reactant,
887+
y_reactant,
888+
)
889+
)[1] x .+ y
890+
end
891+
892+
function f_repeat(x, y)
893+
for _ in 1:3
894+
x, = Ops.hlo_call(
895+
"""
896+
module {
897+
func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
898+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
899+
return %0 : tensor<3xf32>
900+
}
901+
}
902+
""",
903+
x,
904+
y;
905+
func_name="my_add",
906+
)
907+
end
908+
return x
909+
end
910+
911+
@testset "hlo_call: repeat" begin
912+
x = Reactant.to_rarray(randn(Float32, 3))
913+
y = Reactant.to_rarray(randn(Float32, 3))
914+
mod = Reactant.@code_hlo optimize = false f_repeat(x, y)
915+
hlo_ir = repr(mod)
916+
917+
add_pos = findfirst("stablehlo.add", hlo_ir)
918+
@test !isnothing(add_pos)
919+
920+
add_pos = findfirst("stablehlo.add", hlo_ir[last(add_pos):end])
921+
@test isnothing(add_pos)
922+
end
923+
924+
@testset "hlo_call: multiple functions" begin
925+
@test Reactant.@jit(
926+
Ops.hlo_call(
927+
"""
928+
module {
929+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
930+
%0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
931+
return %0 : tensor<3xf32>
932+
}
933+
func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
934+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
935+
return %0 : tensor<3xf32>
936+
}
937+
}
938+
""",
939+
Reactant.to_rarray(Float32[1, 2, 3]),
940+
Reactant.to_rarray(Float32[1, 2, 3]),
941+
)
942+
)[1] Float32[2, 4, 6]
943+
end
944+
945+
function f_multiple_hlo_calls(x, y)
946+
x, = Ops.hlo_call(
947+
"""
948+
module {
949+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
950+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
951+
return %0 : tensor<3xf32>
952+
}
953+
}
954+
""",
955+
x,
956+
y,
957+
)
958+
return Ops.hlo_call(
959+
"""
960+
module {
961+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
962+
%0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
963+
return %0 : tensor<3xf32>
964+
}
965+
}
966+
""",
967+
x,
968+
y,
969+
)
970+
end
971+
972+
@testset "hlo_call: multiple hlo_calls" begin
973+
x = Float32[1.0, 2.0, 50.0]
974+
y = Float32[-4.0, 0.001, 2.0]
975+
x_reactant = Reactant.to_rarray(x)
976+
y_reactant = Reactant.to_rarray(y)
977+
978+
@test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] (x .+ y) .* y
979+
end

0 commit comments

Comments
 (0)