@@ -52,10 +52,12 @@ function __init__()
52
52
return nothing
53
53
end
54
54
55
- @noinline function generate! (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
56
- argprefix:: Symbol = gensym (" generatearg" )
57
- resprefix:: Symbol = gensym (" generateresult" )
58
- resargprefix:: Symbol = gensym (" generateresarg" )
55
+ @noinline function sample! (
56
+ f:: Function , args:: Vararg{Any,Nargs} ; symbol:: Symbol = gensym (" sample" )
57
+ ) where {Nargs}
58
+ argprefix:: Symbol = gensym (" samplearg" )
59
+ resprefix:: Symbol = gensym (" sampleresult" )
60
+ resargprefix:: Symbol = gensym (" sampleresarg" )
59
61
60
62
mlir_fn_res = invokelatest (
61
63
TracedUtils. make_mlir_fn,
70
72
resprefix,
71
73
resargprefix,
72
74
)
73
- (; result, linear_args, in_tys, linear_results) = mlir_fn_res
75
+ (; result, linear_args, linear_results) = mlir_fn_res
74
76
fnwrap = mlir_fn_res. fnwrapped
75
77
func2 = mlir_fn_res. f
76
78
77
- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
78
- fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
79
- fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
80
-
81
79
batch_inputs = MLIR. IR. Value[]
82
80
for a in linear_args
83
81
idx, path = TracedUtils. get_argidx (a, argprefix)
84
82
if idx == 1 && fnwrap
85
83
TracedUtils. push_val! (batch_inputs, f, path[3 : end ])
86
84
else
87
- if fnwrap
88
- idx -= 1
89
- end
85
+ idx -= fnwrap ? 1 : 0
90
86
TracedUtils. push_val! (batch_inputs, args[idx], path[3 : end ])
91
87
end
92
88
end
93
89
94
- gen_op = MLIR. Dialects . enzyme . generate (batch_inputs; outputs = out_tys, fn = fname)
90
+ out_tys = [ MLIR. IR . type (TracedUtils . get_mlir_data (res)) for res in linear_results]
95
91
92
+ sym = TracedUtils. get_attribute_by_name (func2, " sym_name" )
93
+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (sym))
94
+
95
+ traced_output_indices = Int[]
96
96
for (i, res) in enumerate (linear_results)
97
- resv = MLIR. IR. result (gen_op, i)
97
+ if TracedUtils. has_idx (res, resprefix)
98
+ push! (traced_output_indices, i - 1 )
99
+ end
100
+ end
101
+
102
+ symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
103
+
104
+ sample_op = MLIR. Dialects. enzyme. sample (
105
+ batch_inputs;
106
+ outputs= out_tys,
107
+ fn= fn_attr,
108
+ symbol= symbol_addr,
109
+ traced_output_indices= traced_output_indices,
110
+ )
111
+
112
+ for (i, res) in enumerate (linear_results)
113
+ resv = MLIR. IR. result (sample_op, i)
98
114
if TracedUtils. has_idx (res, resprefix)
99
115
path = TracedUtils. get_idx (res, resprefix)
100
116
TracedUtils. set! (result, path[2 : end ], resv)
@@ -116,12 +132,10 @@ end
116
132
return result
117
133
end
118
134
119
- @noinline function sample! (
120
- f:: Function , args:: Vararg{Any,Nargs} ; symbol:: Symbol = gensym (" sample" )
121
- ) where {Nargs}
122
- argprefix:: Symbol = gensym (" samplearg" )
123
- resprefix:: Symbol = gensym (" sampleresult" )
124
- resargprefix:: Symbol = gensym (" sampleresarg" )
135
+ @noinline function generate! (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
136
+ argprefix:: Symbol = gensym (" generatearg" )
137
+ resprefix:: Symbol = gensym (" generateresult" )
138
+ resargprefix:: Symbol = gensym (" generateresarg" )
125
139
126
140
mlir_fn_res = invokelatest (
127
141
TracedUtils. make_mlir_fn,
@@ -136,45 +150,31 @@ end
136
150
resprefix,
137
151
resargprefix,
138
152
)
139
- (; result, linear_args, linear_results) = mlir_fn_res
153
+ (; result, linear_args, in_tys, linear_results) = mlir_fn_res
140
154
fnwrap = mlir_fn_res. fnwrapped
141
155
func2 = mlir_fn_res. f
142
156
157
+ out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
158
+ fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
159
+ fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
160
+
143
161
batch_inputs = MLIR. IR. Value[]
144
162
for a in linear_args
145
163
idx, path = TracedUtils. get_argidx (a, argprefix)
146
164
if idx == 1 && fnwrap
147
165
TracedUtils. push_val! (batch_inputs, f, path[3 : end ])
148
166
else
149
- idx -= fnwrap ? 1 : 0
167
+ if fnwrap
168
+ idx -= 1
169
+ end
150
170
TracedUtils. push_val! (batch_inputs, args[idx], path[3 : end ])
151
171
end
152
172
end
153
173
154
- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
155
-
156
- sym = TracedUtils. get_attribute_by_name (func2, " sym_name" )
157
- fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (sym))
158
-
159
- traced_output_indices = Int[]
160
- for (i, res) in enumerate (linear_results)
161
- if TracedUtils. has_idx (res, resprefix)
162
- push! (traced_output_indices, i - 1 )
163
- end
164
- end
165
-
166
- symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
167
-
168
- sample_op = MLIR. Dialects. enzyme. sample (
169
- batch_inputs;
170
- outputs= out_tys,
171
- fn= fn_attr,
172
- symbol= symbol_addr,
173
- traced_output_indices= traced_output_indices,
174
- )
174
+ gen_op = MLIR. Dialects. enzyme. generate (batch_inputs; outputs= out_tys, fn= fname)
175
175
176
176
for (i, res) in enumerate (linear_results)
177
- resv = MLIR. IR. result (sample_op , i)
177
+ resv = MLIR. IR. result (gen_op , i)
178
178
if TracedUtils. has_idx (res, resprefix)
179
179
path = TracedUtils. get_idx (res, resprefix)
180
180
TracedUtils. set! (result, path[2 : end ], resv)
@@ -278,5 +278,4 @@ function print_trace(trace::Dict{Symbol,Any})
278
278
end
279
279
return println (" ### End of Trace ###" )
280
280
end
281
-
282
- end
281
+ end
0 commit comments