@@ -13,6 +13,84 @@ import ...IR:
13
13
import .. Dialects: namedattribute, operandsegmentsizes
14
14
import ... API
15
15
16
+ """
17
+ `addRetvalToTrace`
18
+
19
+ Add the function\' s return value(s) into the execution trace.
20
+ """
21
+ function addRetvalToTrace (
22
+ trace:: Value , retval:: Vector{Value} ; updated_trace:: IR.Type , location= Location ()
23
+ )
24
+ op_ty_results = IR. Type[updated_trace,]
25
+ operands = Value[trace, retval... ]
26
+ owned_regions = Region[]
27
+ successors = Block[]
28
+ attributes = NamedAttribute[]
29
+
30
+ return create_operation (
31
+ " enzyme.addRetvalToTrace" ,
32
+ location;
33
+ operands,
34
+ owned_regions,
35
+ successors,
36
+ attributes,
37
+ results= op_ty_results,
38
+ result_inference= false ,
39
+ )
40
+ end
41
+
42
+ """
43
+ `addSampleToTrace`
44
+
45
+ Add a sampled value into the execution trace.
46
+ """
47
+ function addSampleToTrace (
48
+ trace:: Value , sample:: Vector{Value} ; updated_trace:: IR.Type , symbol, location= Location ()
49
+ )
50
+ op_ty_results = IR. Type[updated_trace,]
51
+ operands = Value[trace, sample... ]
52
+ owned_regions = Region[]
53
+ successors = Block[]
54
+ attributes = NamedAttribute[namedattribute (" symbol" , symbol),]
55
+
56
+ return create_operation (
57
+ " enzyme.addSampleToTrace" ,
58
+ location;
59
+ operands,
60
+ owned_regions,
61
+ successors,
62
+ attributes,
63
+ results= op_ty_results,
64
+ result_inference= false ,
65
+ )
66
+ end
67
+
68
+ """
69
+ `addSubtrace`
70
+
71
+ Insert a subtrace into a parent trace.
72
+ """
73
+ function addSubtrace (
74
+ subtrace:: Value , trace:: Value ; updated_trace:: IR.Type , symbol, location= Location ()
75
+ )
76
+ op_ty_results = IR. Type[updated_trace,]
77
+ operands = Value[subtrace, trace]
78
+ owned_regions = Region[]
79
+ successors = Block[]
80
+ attributes = NamedAttribute[namedattribute (" symbol" , symbol),]
81
+
82
+ return create_operation (
83
+ " enzyme.addSubtrace" ,
84
+ location;
85
+ operands,
86
+ owned_regions,
87
+ successors,
88
+ attributes,
89
+ results= op_ty_results,
90
+ result_inference= false ,
91
+ )
92
+ end
93
+
16
94
"""
17
95
`addTo`
18
96
@@ -37,6 +115,32 @@ function addTo(values::Vector{Value}; location=Location())
37
115
)
38
116
end
39
117
118
+ """
119
+ `addWeightToTrace`
120
+
121
+ Add the aggregated log-probability weight to the execution trace.
122
+ """
123
+ function addWeightToTrace (
124
+ trace:: Value , weight:: Value ; updated_trace:: IR.Type , location= Location ()
125
+ )
126
+ op_ty_results = IR. Type[updated_trace,]
127
+ operands = Value[trace, weight]
128
+ owned_regions = Region[]
129
+ successors = Block[]
130
+ attributes = NamedAttribute[]
131
+
132
+ return create_operation (
133
+ " enzyme.addWeightToTrace" ,
134
+ location;
135
+ operands,
136
+ owned_regions,
137
+ successors,
138
+ attributes,
139
+ results= op_ty_results,
140
+ result_inference= false ,
141
+ )
142
+ end
143
+
40
144
function autodiff (
41
145
inputs:: Vector{Value} ;
42
146
outputs:: Vector{IR.Type} ,
@@ -155,6 +259,49 @@ function fwddiff(
155
259
)
156
260
end
157
261
262
+ """
263
+ `generate`
264
+
265
+ Generate an execution trace and weight from a probabilistic function.
266
+ If a `constraint` dict is provided AND the sample op\' s `symbol` is in the
267
+ `constrained_symbols` array, we will use the corresponding constraint value
268
+ instead of generating new samples from the probabilistic function.
269
+ By convention, the 0th operand in `inputs` or `outputs` is the initial RNG
270
+ state (seed).
271
+ """
272
+ function generate (
273
+ inputs:: Vector{Value} ,
274
+ constraint:: Value ;
275
+ trace:: IR.Type ,
276
+ weight:: IR.Type ,
277
+ outputs:: Vector{IR.Type} ,
278
+ fn,
279
+ constrained_addresses,
280
+ name= nothing ,
281
+ location= Location (),
282
+ )
283
+ op_ty_results = IR. Type[trace, weight, outputs... ]
284
+ operands = Value[inputs... , constraint]
285
+ owned_regions = Region[]
286
+ successors = Block[]
287
+ attributes = NamedAttribute[
288
+ namedattribute (" fn" , fn),
289
+ namedattribute (" constrained_addresses" , constrained_addresses),
290
+ ]
291
+ ! isnothing (name) && push! (attributes, namedattribute (" name" , name))
292
+
293
+ return create_operation (
294
+ " enzyme.generate" ,
295
+ location;
296
+ operands,
297
+ owned_regions,
298
+ successors,
299
+ attributes,
300
+ results= op_ty_results,
301
+ result_inference= false ,
302
+ )
303
+ end
304
+
158
305
function genericAdjoint (
159
306
inputs:: Vector{Value} ,
160
307
outputs:: Vector{Value} ;
@@ -210,6 +357,58 @@ function get(gradient::Value; result_0::IR.Type, location=Location())
210
357
)
211
358
end
212
359
360
+ """
361
+ `getSampleFromConstraint`
362
+
363
+ Get sampled values from a constraint for a given symbol.
364
+ """
365
+ function getSampleFromConstraint (
366
+ constraint:: Value ; outputs:: Vector{IR.Type} , symbol, location= Location ()
367
+ )
368
+ op_ty_results = IR. Type[outputs... ,]
369
+ operands = Value[constraint,]
370
+ owned_regions = Region[]
371
+ successors = Block[]
372
+ attributes = NamedAttribute[namedattribute (" symbol" , symbol),]
373
+
374
+ return create_operation (
375
+ " enzyme.getSampleFromConstraint" ,
376
+ location;
377
+ operands,
378
+ owned_regions,
379
+ successors,
380
+ attributes,
381
+ results= op_ty_results,
382
+ result_inference= false ,
383
+ )
384
+ end
385
+
386
+ """
387
+ `getSubconstraint`
388
+
389
+ Get a subconstraint from a constraint for a given symbol.
390
+ """
391
+ function getSubconstraint (
392
+ constraint:: Value ; subconstraint:: IR.Type , symbol, location= Location ()
393
+ )
394
+ op_ty_results = IR. Type[subconstraint,]
395
+ operands = Value[constraint,]
396
+ owned_regions = Region[]
397
+ successors = Block[]
398
+ attributes = NamedAttribute[namedattribute (" symbol" , symbol),]
399
+
400
+ return create_operation (
401
+ " enzyme.getSubconstraint" ,
402
+ location;
403
+ operands,
404
+ owned_regions,
405
+ successors,
406
+ attributes,
407
+ results= op_ty_results,
408
+ result_inference= false ,
409
+ )
410
+ end
411
+
213
412
function ignore_derivatives (input:: Value ; output:: IR.Type , location= Location ())
214
413
op_ty_results = IR. Type[output,]
215
414
operands = Value[input,]
@@ -248,6 +447,30 @@ function init(; result_0::IR.Type, location=Location())
248
447
)
249
448
end
250
449
450
+ """
451
+ `initTrace`
452
+
453
+ Initialize an execution trace for a probabilistic function.
454
+ """
455
+ function initTrace (; trace:: IR.Type , location= Location ())
456
+ op_ty_results = IR. Type[trace,]
457
+ operands = Value[]
458
+ owned_regions = Region[]
459
+ successors = Block[]
460
+ attributes = NamedAttribute[]
461
+
462
+ return create_operation (
463
+ " enzyme.initTrace" ,
464
+ location;
465
+ operands,
466
+ owned_regions,
467
+ successors,
468
+ attributes,
469
+ results= op_ty_results,
470
+ result_inference= false ,
471
+ )
472
+ end
473
+
251
474
function placeholder (; output:: IR.Type , location= Location ())
252
475
op_ty_results = IR. Type[output,]
253
476
operands = Value[]
@@ -305,14 +528,28 @@ function push(cache::Value, value::Value; location=Location())
305
528
)
306
529
end
307
530
531
+ """
532
+ `sample`
533
+
534
+ Sample from a distribution. By convention, the 0th operand in `inputs`
535
+ or `outputs` is the initial RNG state (seed).
536
+ """
308
537
function sample (
309
- inputs:: Vector{Value} ; outputs:: Vector{IR.Type} , fn, name= nothing , location= Location ()
538
+ inputs:: Vector{Value} ;
539
+ outputs:: Vector{IR.Type} ,
540
+ fn,
541
+ logpdf= nothing ,
542
+ symbol= nothing ,
543
+ name= nothing ,
544
+ location= Location (),
310
545
)
311
546
op_ty_results = IR. Type[outputs... ,]
312
547
operands = Value[inputs... ,]
313
548
owned_regions = Region[]
314
549
successors = Block[]
315
550
attributes = NamedAttribute[namedattribute (" fn" , fn),]
551
+ ! isnothing (logpdf) && push! (attributes, namedattribute (" logpdf" , logpdf))
552
+ ! isnothing (symbol) && push! (attributes, namedattribute (" symbol" , symbol))
316
553
! isnothing (name) && push! (attributes, namedattribute (" name" , name))
317
554
318
555
return create_operation (
@@ -346,4 +583,69 @@ function set(gradient::Value, value::Value; location=Location())
346
583
)
347
584
end
348
585
586
+ """
587
+ `simulate`
588
+
589
+ Simulate a probabilistic function to generate execution trace
590
+ by replacing all SampleOps with distribution calls and recording
591
+ all sampled values into the trace. This op returns the trace, the weight
592
+ (accumulated log-probability), and the other outputs. By convention,
593
+ the 0th operand in `inputs` or `outputs` is the initial RNG state (seed).
594
+ """
595
+ function simulate (
596
+ inputs:: Vector{Value} ;
597
+ trace:: IR.Type ,
598
+ weight:: IR.Type ,
599
+ outputs:: Vector{IR.Type} ,
600
+ fn,
601
+ name= nothing ,
602
+ location= Location (),
603
+ )
604
+ op_ty_results = IR. Type[trace, weight, outputs... ]
605
+ operands = Value[inputs... ,]
606
+ owned_regions = Region[]
607
+ successors = Block[]
608
+ attributes = NamedAttribute[namedattribute (" fn" , fn),]
609
+ ! isnothing (name) && push! (attributes, namedattribute (" name" , name))
610
+
611
+ return create_operation (
612
+ " enzyme.simulate" ,
613
+ location;
614
+ operands,
615
+ owned_regions,
616
+ successors,
617
+ attributes,
618
+ results= op_ty_results,
619
+ result_inference= false ,
620
+ )
621
+ end
622
+
623
+ """
624
+ `untracedCall`
625
+
626
+ Call a probabilistic function without tracing. By convention, the 0th operand in `inputs`
627
+ or `outputs` is the initial RNG state (seed).
628
+ """
629
+ function untracedCall (
630
+ inputs:: Vector{Value} ; outputs:: Vector{IR.Type} , fn, name= nothing , location= Location ()
631
+ )
632
+ op_ty_results = IR. Type[outputs... ,]
633
+ operands = Value[inputs... ,]
634
+ owned_regions = Region[]
635
+ successors = Block[]
636
+ attributes = NamedAttribute[namedattribute (" fn" , fn),]
637
+ ! isnothing (name) && push! (attributes, namedattribute (" name" , name))
638
+
639
+ return create_operation (
640
+ " enzyme.untracedCall" ,
641
+ location;
642
+ operands,
643
+ owned_regions,
644
+ successors,
645
+ attributes,
646
+ results= op_ty_results,
647
+ result_inference= false ,
648
+ )
649
+ end
650
+
349
651
end # enzyme
0 commit comments