Skip to content

Commit 7b3ad3c

Browse files
Regenerate MLIR Bindings (#1529)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 995be92 commit 7b3ad3c

File tree

1 file changed

+303
-1
lines changed

1 file changed

+303
-1
lines changed

src/mlir/Dialects/Enzyme.jl

Lines changed: 303 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,84 @@ import ...IR:
1313
import ..Dialects: namedattribute, operandsegmentsizes
1414
import ...API
1515

16+
"""
17+
`addRetvalToTrace`
18+
19+
Add the function\'s return value(s) into the execution trace.
20+
"""
21+
function addRetvalToTrace(
22+
trace::Value, retval::Vector{Value}; updated_trace::IR.Type, location=Location()
23+
)
24+
op_ty_results = IR.Type[updated_trace,]
25+
operands = Value[trace, retval...]
26+
owned_regions = Region[]
27+
successors = Block[]
28+
attributes = NamedAttribute[]
29+
30+
return create_operation(
31+
"enzyme.addRetvalToTrace",
32+
location;
33+
operands,
34+
owned_regions,
35+
successors,
36+
attributes,
37+
results=op_ty_results,
38+
result_inference=false,
39+
)
40+
end
41+
42+
"""
43+
`addSampleToTrace`
44+
45+
Add a sampled value into the execution trace.
46+
"""
47+
function addSampleToTrace(
48+
trace::Value, sample::Vector{Value}; updated_trace::IR.Type, symbol, location=Location()
49+
)
50+
op_ty_results = IR.Type[updated_trace,]
51+
operands = Value[trace, sample...]
52+
owned_regions = Region[]
53+
successors = Block[]
54+
attributes = NamedAttribute[namedattribute("symbol", symbol),]
55+
56+
return create_operation(
57+
"enzyme.addSampleToTrace",
58+
location;
59+
operands,
60+
owned_regions,
61+
successors,
62+
attributes,
63+
results=op_ty_results,
64+
result_inference=false,
65+
)
66+
end
67+
68+
"""
69+
`addSubtrace`
70+
71+
Insert a subtrace into a parent trace.
72+
"""
73+
function addSubtrace(
74+
subtrace::Value, trace::Value; updated_trace::IR.Type, symbol, location=Location()
75+
)
76+
op_ty_results = IR.Type[updated_trace,]
77+
operands = Value[subtrace, trace]
78+
owned_regions = Region[]
79+
successors = Block[]
80+
attributes = NamedAttribute[namedattribute("symbol", symbol),]
81+
82+
return create_operation(
83+
"enzyme.addSubtrace",
84+
location;
85+
operands,
86+
owned_regions,
87+
successors,
88+
attributes,
89+
results=op_ty_results,
90+
result_inference=false,
91+
)
92+
end
93+
1694
"""
1795
`addTo`
1896
@@ -37,6 +115,32 @@ function addTo(values::Vector{Value}; location=Location())
37115
)
38116
end
39117

118+
"""
119+
`addWeightToTrace`
120+
121+
Add the aggregated log-probability weight to the execution trace.
122+
"""
123+
function addWeightToTrace(
124+
trace::Value, weight::Value; updated_trace::IR.Type, location=Location()
125+
)
126+
op_ty_results = IR.Type[updated_trace,]
127+
operands = Value[trace, weight]
128+
owned_regions = Region[]
129+
successors = Block[]
130+
attributes = NamedAttribute[]
131+
132+
return create_operation(
133+
"enzyme.addWeightToTrace",
134+
location;
135+
operands,
136+
owned_regions,
137+
successors,
138+
attributes,
139+
results=op_ty_results,
140+
result_inference=false,
141+
)
142+
end
143+
40144
function autodiff(
41145
inputs::Vector{Value};
42146
outputs::Vector{IR.Type},
@@ -155,6 +259,49 @@ function fwddiff(
155259
)
156260
end
157261

262+
"""
263+
`generate`
264+
265+
Generate an execution trace and weight from a probabilistic function.
266+
If a `constraint` dict is provided AND the sample op\'s `symbol` is in the
267+
`constrained_symbols` array, we will use the corresponding constraint value
268+
instead of generating new samples from the probabilistic function.
269+
By convention, the 0th operand in `inputs` or `outputs` is the initial RNG
270+
state (seed).
271+
"""
272+
function generate(
273+
inputs::Vector{Value},
274+
constraint::Value;
275+
trace::IR.Type,
276+
weight::IR.Type,
277+
outputs::Vector{IR.Type},
278+
fn,
279+
constrained_addresses,
280+
name=nothing,
281+
location=Location(),
282+
)
283+
op_ty_results = IR.Type[trace, weight, outputs...]
284+
operands = Value[inputs..., constraint]
285+
owned_regions = Region[]
286+
successors = Block[]
287+
attributes = NamedAttribute[
288+
namedattribute("fn", fn),
289+
namedattribute("constrained_addresses", constrained_addresses),
290+
]
291+
!isnothing(name) && push!(attributes, namedattribute("name", name))
292+
293+
return create_operation(
294+
"enzyme.generate",
295+
location;
296+
operands,
297+
owned_regions,
298+
successors,
299+
attributes,
300+
results=op_ty_results,
301+
result_inference=false,
302+
)
303+
end
304+
158305
function genericAdjoint(
159306
inputs::Vector{Value},
160307
outputs::Vector{Value};
@@ -210,6 +357,58 @@ function get(gradient::Value; result_0::IR.Type, location=Location())
210357
)
211358
end
212359

360+
"""
361+
`getSampleFromConstraint`
362+
363+
Get sampled values from a constraint for a given symbol.
364+
"""
365+
function getSampleFromConstraint(
366+
constraint::Value; outputs::Vector{IR.Type}, symbol, location=Location()
367+
)
368+
op_ty_results = IR.Type[outputs...,]
369+
operands = Value[constraint,]
370+
owned_regions = Region[]
371+
successors = Block[]
372+
attributes = NamedAttribute[namedattribute("symbol", symbol),]
373+
374+
return create_operation(
375+
"enzyme.getSampleFromConstraint",
376+
location;
377+
operands,
378+
owned_regions,
379+
successors,
380+
attributes,
381+
results=op_ty_results,
382+
result_inference=false,
383+
)
384+
end
385+
386+
"""
387+
`getSubconstraint`
388+
389+
Get a subconstraint from a constraint for a given symbol.
390+
"""
391+
function getSubconstraint(
392+
constraint::Value; subconstraint::IR.Type, symbol, location=Location()
393+
)
394+
op_ty_results = IR.Type[subconstraint,]
395+
operands = Value[constraint,]
396+
owned_regions = Region[]
397+
successors = Block[]
398+
attributes = NamedAttribute[namedattribute("symbol", symbol),]
399+
400+
return create_operation(
401+
"enzyme.getSubconstraint",
402+
location;
403+
operands,
404+
owned_regions,
405+
successors,
406+
attributes,
407+
results=op_ty_results,
408+
result_inference=false,
409+
)
410+
end
411+
213412
function ignore_derivatives(input::Value; output::IR.Type, location=Location())
214413
op_ty_results = IR.Type[output,]
215414
operands = Value[input,]
@@ -248,6 +447,30 @@ function init(; result_0::IR.Type, location=Location())
248447
)
249448
end
250449

450+
"""
451+
`initTrace`
452+
453+
Initialize an execution trace for a probabilistic function.
454+
"""
455+
function initTrace(; trace::IR.Type, location=Location())
456+
op_ty_results = IR.Type[trace,]
457+
operands = Value[]
458+
owned_regions = Region[]
459+
successors = Block[]
460+
attributes = NamedAttribute[]
461+
462+
return create_operation(
463+
"enzyme.initTrace",
464+
location;
465+
operands,
466+
owned_regions,
467+
successors,
468+
attributes,
469+
results=op_ty_results,
470+
result_inference=false,
471+
)
472+
end
473+
251474
function placeholder(; output::IR.Type, location=Location())
252475
op_ty_results = IR.Type[output,]
253476
operands = Value[]
@@ -305,14 +528,28 @@ function push(cache::Value, value::Value; location=Location())
305528
)
306529
end
307530

531+
"""
532+
`sample`
533+
534+
Sample from a distribution. By convention, the 0th operand in `inputs`
535+
or `outputs` is the initial RNG state (seed).
536+
"""
308537
function sample(
309-
inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()
538+
inputs::Vector{Value};
539+
outputs::Vector{IR.Type},
540+
fn,
541+
logpdf=nothing,
542+
symbol=nothing,
543+
name=nothing,
544+
location=Location(),
310545
)
311546
op_ty_results = IR.Type[outputs...,]
312547
operands = Value[inputs...,]
313548
owned_regions = Region[]
314549
successors = Block[]
315550
attributes = NamedAttribute[namedattribute("fn", fn),]
551+
!isnothing(logpdf) && push!(attributes, namedattribute("logpdf", logpdf))
552+
!isnothing(symbol) && push!(attributes, namedattribute("symbol", symbol))
316553
!isnothing(name) && push!(attributes, namedattribute("name", name))
317554

318555
return create_operation(
@@ -346,4 +583,69 @@ function set(gradient::Value, value::Value; location=Location())
346583
)
347584
end
348585

586+
"""
587+
`simulate`
588+
589+
Simulate a probabilistic function to generate execution trace
590+
by replacing all SampleOps with distribution calls and recording
591+
all sampled values into the trace. This op returns the trace, the weight
592+
(accumulated log-probability), and the other outputs. By convention,
593+
the 0th operand in `inputs` or `outputs` is the initial RNG state (seed).
594+
"""
595+
function simulate(
596+
inputs::Vector{Value};
597+
trace::IR.Type,
598+
weight::IR.Type,
599+
outputs::Vector{IR.Type},
600+
fn,
601+
name=nothing,
602+
location=Location(),
603+
)
604+
op_ty_results = IR.Type[trace, weight, outputs...]
605+
operands = Value[inputs...,]
606+
owned_regions = Region[]
607+
successors = Block[]
608+
attributes = NamedAttribute[namedattribute("fn", fn),]
609+
!isnothing(name) && push!(attributes, namedattribute("name", name))
610+
611+
return create_operation(
612+
"enzyme.simulate",
613+
location;
614+
operands,
615+
owned_regions,
616+
successors,
617+
attributes,
618+
results=op_ty_results,
619+
result_inference=false,
620+
)
621+
end
622+
623+
"""
624+
`untracedCall`
625+
626+
Call a probabilistic function without tracing. By convention, the 0th operand in `inputs`
627+
or `outputs` is the initial RNG state (seed).
628+
"""
629+
function untracedCall(
630+
inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()
631+
)
632+
op_ty_results = IR.Type[outputs...,]
633+
operands = Value[inputs...,]
634+
owned_regions = Region[]
635+
successors = Block[]
636+
attributes = NamedAttribute[namedattribute("fn", fn),]
637+
!isnothing(name) && push!(attributes, namedattribute("name", name))
638+
639+
return create_operation(
640+
"enzyme.untracedCall",
641+
location;
642+
operands,
643+
owned_regions,
644+
successors,
645+
attributes,
646+
results=op_ty_results,
647+
result_inference=false,
648+
)
649+
end
650+
349651
end # enzyme

0 commit comments

Comments
 (0)