@@ -1046,4 +1046,127 @@ function compare(
10461046 return TracedRArray {Bool,ndims(lhs)} ((), res, size (lhs))
10471047end
10481048
1049+ # Generate a unique name given a module hash and a function name.
1050+ function _hlo_call_name (orig_name, module_suffix)
1051+ return orig_name * " _hlo_call_" * module_suffix
10491052end
1053+
1054+ """
1055+ Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1056+
1057+ Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1058+ with the provided arguments and return a tuple for each result of the call.
1059+
1060+ ```julia-repl
1061+ julia> Reactant.@jit(
1062+ Ops.hlo_call(
1063+ \"\"\"
1064+ module {
1065+ func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1066+ %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1067+ return %0 : tensor<3xf32>
1068+ }
1069+ }
1070+ \"\"\" ,
1071+ Reactant.to_rarray(Float32[1, 2, 3]),
1072+ Reactant.to_rarray(Float32[1, 2, 3]),
1073+ )
1074+ )
1075+ (ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1076+ ```
1077+ """
1078+ function hlo_call (
1079+ code,
1080+ args... ;
1081+ func_name= " main" ,
1082+ location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1083+ )
1084+ module_suffix = string (hash (code); base= 16 )
1085+ name_to_call = _hlo_call_name (func_name, module_suffix)
1086+
1087+ current_module = MLIR. IR. mmodule ()
1088+ top_level_block = MLIR. IR. body (current_module)
1089+
1090+ symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
1091+
1092+ fn = MLIR. IR. lookup (
1093+ MLIR. IR. SymbolTable (MLIR. IR. Operation (current_module)), name_to_call
1094+ )
1095+ if isnothing (fn)
1096+ new_mod = parse (MLIR. IR. Module, code)
1097+ new_mod_op = MLIR. IR. Operation (new_mod)
1098+ body = MLIR. IR. body (new_mod)
1099+
1100+ operations = collect (MLIR. IR. OperationIterator (body))
1101+ for op in operations
1102+ if MLIR. IR. name (op) == " func.func"
1103+ fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
1104+ if fn_name == func_name
1105+ fn = op
1106+ end
1107+
1108+ new_name = _hlo_call_name (fn_name, module_suffix)
1109+ res = MLIR. IR. LogicalResult (
1110+ MLIR. API. mlirSymbolTableReplaceAllSymbolUses (
1111+ fn_name, new_name, new_mod_op
1112+ ),
1113+ )
1114+ @assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
1115+
1116+ # Set function private
1117+ MLIR. IR. attr! (
1118+ op,
1119+ MLIR. API. mlirSymbolTableGetVisibilityAttributeName (),
1120+ MLIR. IR. Attribute (" private" ),
1121+ )
1122+
1123+ # Change function name
1124+ MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (new_name))
1125+ end
1126+ end
1127+
1128+ for op in operations
1129+ MLIR. IR. rmfromparent! (op)
1130+ push! (top_level_block, op)
1131+ end
1132+ end
1133+
1134+ if isnothing (fn)
1135+ error (" hlo_call: could not find function $func_name in the provided module" )
1136+ end
1137+
1138+ ftype_attr = MLIR. IR. attr (fn, " function_type" )
1139+ ftype = MLIR. IR. Type (ftype_attr)
1140+
1141+ @assert all (Base. Fix2 (isa, Reactant. AnyTracedRArray), args) " hlo_call: all inputs to hlo_call should be reactant arrays"
1142+ @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name "
1143+
1144+ for (i, arg) in enumerate (args)
1145+ expected_type = MLIR. IR. input (ftype, i)
1146+ arg_type = MLIR. IR. type (arg. mlir_data)
1147+ @assert expected_type == arg_type " hlo_call: argument #$i has the wrong type (expected $expected_type , got $arg_type )"
1148+ end
1149+
1150+ operands = [a. mlir_data for a in args]
1151+ call = MLIR. Dialects. func. call (
1152+ operands;
1153+ result_0= [MLIR. IR. result (ftype, i) for i in 1 : MLIR. IR. nresults (ftype)],
1154+ callee= MLIR. IR. FlatSymbolRefAttribute (name_to_call),
1155+ location,
1156+ )
1157+
1158+ return ntuple (MLIR. IR. nresults (call)) do i
1159+ out = MLIR. IR. result (call, i)
1160+ ty = MLIR. IR. type (out)
1161+ sz = MLIR. IR. size (ty)
1162+ T = MLIR. IR. julia_type (eltype (ty))
1163+ N = length (sz)
1164+ if N == 0
1165+ Reactant. TracedRNumber {T} ((), out)
1166+ else
1167+ Reactant. TracedRArray {T,N} ((), out, sz)
1168+ end
1169+ end
1170+ end
1171+
1172+ end # module Ops
0 commit comments