Skip to content

Commit d707053

Browse files
committed
reorder
1 parent 1c5297c commit d707053

File tree

1 file changed

+45
-46
lines changed

1 file changed

+45
-46
lines changed

src/ProbProg.jl

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ function __init__()
5252
return nothing
5353
end
5454

55-
@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
56-
argprefix::Symbol = gensym("generatearg")
57-
resprefix::Symbol = gensym("generateresult")
58-
resargprefix::Symbol = gensym("generateresarg")
55+
@noinline function sample!(
56+
f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample")
57+
) where {Nargs}
58+
argprefix::Symbol = gensym("samplearg")
59+
resprefix::Symbol = gensym("sampleresult")
60+
resargprefix::Symbol = gensym("sampleresarg")
5961

6062
mlir_fn_res = invokelatest(
6163
TracedUtils.make_mlir_fn,
@@ -70,31 +72,45 @@ end
7072
resprefix,
7173
resargprefix,
7274
)
73-
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
75+
(; result, linear_args, linear_results) = mlir_fn_res
7476
fnwrap = mlir_fn_res.fnwrapped
7577
func2 = mlir_fn_res.f
7678

77-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
78-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
79-
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
80-
8179
batch_inputs = MLIR.IR.Value[]
8280
for a in linear_args
8381
idx, path = TracedUtils.get_argidx(a, argprefix)
8482
if idx == 1 && fnwrap
8583
TracedUtils.push_val!(batch_inputs, f, path[3:end])
8684
else
87-
if fnwrap
88-
idx -= 1
89-
end
85+
idx -= fnwrap ? 1 : 0
9086
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
9187
end
9288
end
9389

94-
gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname)
90+
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
9591

92+
sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
93+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))
94+
95+
traced_output_indices = Int[]
9696
for (i, res) in enumerate(linear_results)
97-
resv = MLIR.IR.result(gen_op, i)
97+
if TracedUtils.has_idx(res, resprefix)
98+
push!(traced_output_indices, i - 1)
99+
end
100+
end
101+
102+
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
103+
104+
sample_op = MLIR.Dialects.enzyme.sample(
105+
batch_inputs;
106+
outputs=out_tys,
107+
fn=fn_attr,
108+
symbol=symbol_addr,
109+
traced_output_indices=traced_output_indices,
110+
)
111+
112+
for (i, res) in enumerate(linear_results)
113+
resv = MLIR.IR.result(sample_op, i)
98114
if TracedUtils.has_idx(res, resprefix)
99115
path = TracedUtils.get_idx(res, resprefix)
100116
TracedUtils.set!(result, path[2:end], resv)
@@ -116,12 +132,10 @@ end
116132
return result
117133
end
118134

119-
@noinline function sample!(
120-
f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample")
121-
) where {Nargs}
122-
argprefix::Symbol = gensym("samplearg")
123-
resprefix::Symbol = gensym("sampleresult")
124-
resargprefix::Symbol = gensym("sampleresarg")
135+
@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
136+
argprefix::Symbol = gensym("generatearg")
137+
resprefix::Symbol = gensym("generateresult")
138+
resargprefix::Symbol = gensym("generateresarg")
125139

126140
mlir_fn_res = invokelatest(
127141
TracedUtils.make_mlir_fn,
@@ -136,45 +150,31 @@ end
136150
resprefix,
137151
resargprefix,
138152
)
139-
(; result, linear_args, linear_results) = mlir_fn_res
153+
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
140154
fnwrap = mlir_fn_res.fnwrapped
141155
func2 = mlir_fn_res.f
142156

157+
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
158+
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
159+
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
160+
143161
batch_inputs = MLIR.IR.Value[]
144162
for a in linear_args
145163
idx, path = TracedUtils.get_argidx(a, argprefix)
146164
if idx == 1 && fnwrap
147165
TracedUtils.push_val!(batch_inputs, f, path[3:end])
148166
else
149-
idx -= fnwrap ? 1 : 0
167+
if fnwrap
168+
idx -= 1
169+
end
150170
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
151171
end
152172
end
153173

154-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
155-
156-
sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
157-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))
158-
159-
traced_output_indices = Int[]
160-
for (i, res) in enumerate(linear_results)
161-
if TracedUtils.has_idx(res, resprefix)
162-
push!(traced_output_indices, i - 1)
163-
end
164-
end
165-
166-
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
167-
168-
sample_op = MLIR.Dialects.enzyme.sample(
169-
batch_inputs;
170-
outputs=out_tys,
171-
fn=fn_attr,
172-
symbol=symbol_addr,
173-
traced_output_indices=traced_output_indices,
174-
)
174+
gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname)
175175

176176
for (i, res) in enumerate(linear_results)
177-
resv = MLIR.IR.result(sample_op, i)
177+
resv = MLIR.IR.result(gen_op, i)
178178
if TracedUtils.has_idx(res, resprefix)
179179
path = TracedUtils.get_idx(res, resprefix)
180180
TracedUtils.set!(result, path[2:end], resv)
@@ -278,5 +278,4 @@ function print_trace(trace::Dict{Symbol,Any})
278278
end
279279
return println("### End of Trace ###")
280280
end
281-
282-
end
281+
end

0 commit comments

Comments
 (0)