@@ -5,52 +5,6 @@ const enzyme_dupnoneed = 3
5
5
const enzyme_outnoneed = 4
6
6
const enzyme_constnoneed = 5
7
7
8
- function activate_strongzero! (strongzero:: Bool )
9
- stack = get! (task_local_storage (), :reactant_strongzero ) do
10
- Bool[]
11
- end
12
- push! (stack, strongzero)
13
- return nothing
14
- end
15
-
16
- function deactivate_strongzero! (strongzero:: Bool )
17
- key = :reactant_strongzero
18
- strongzero === last (task_local_storage (key)) ||
19
- error (" Deactivating wrong strong zerocontext" )
20
- return pop! (task_local_storage (key))
21
- end
22
-
23
- function get_strongzero ()
24
- key = :reactant_strongzero
25
- if ! (haskey (task_local_storage (), key) && ! Base. isempty (task_local_storage (key)))
26
- return false
27
- end
28
- return last (task_local_storage (key):: Vector{Bool} )
29
- end
30
-
31
- """
32
- @strongzero() begin
33
- # Derivative calls that require Enzyme to use string zeroing
34
- end
35
-
36
- Whether to enforce multiplication by zero as enforcing a zero result even if multiplying
37
- against a NaN or infinity. Necessary for some programs in which a value has a zero
38
- derivative since it is unused, even if it has an otherwise infinite or nan derivative.
39
-
40
- Outside of reactant this is equivalent to setting the global flag Enzyme.API.strong_zero!(true)
41
- before differentiation. This should be moved into the mode in both cases.
42
- """
43
- macro strongzero (ex)
44
- quote
45
- activate_strongzero! (true )
46
- try
47
- $ (esc (ex))
48
- finally
49
- deactivate_strongzero! (true )
50
- end
51
- end
52
- end
53
-
54
8
function Enzyme. make_zero (
55
9
:: Type{RT} , seen:: IdDict , prev:: RT , :: Val{copy_if_inactive} = Val (false )
56
10
):: RT where {copy_if_inactive,RT<: Union{RArray,RNumber} }
@@ -338,13 +292,9 @@ function overload_autodiff(
338
292
end
339
293
340
294
outtys = MLIR. IR. Type[]
341
- @inline needs_primal (:: Type{<:Enzyme.ReverseMode{ReturnPrimal}} ) where {ReturnPrimal} =
342
- ReturnPrimal
343
- @inline needs_primal (:: Type{<:Enzyme.ForwardMode{ReturnPrimal}} ) where {ReturnPrimal} =
344
- ReturnPrimal
345
295
for a in linear_results
346
296
if TracedUtils. has_idx (a, resprefix)
347
- if needs_primal (CMode)
297
+ if Enzyme . needs_primal (CMode)
348
298
push! (
349
299
outtys,
350
300
TracedUtils. transpose_ty (MLIR. IR. type (TracedUtils. get_mlir_data (a))),
@@ -389,7 +339,7 @@ function overload_autodiff(
389
339
ret_activity = Int32[]
390
340
for a in linear_results
391
341
if TracedUtils. has_idx (a, resprefix)
392
- act = act_from_type (A, reverse, needs_primal (CMode))
342
+ act = act_from_type (A, reverse, Enzyme . needs_primal (CMode))
393
343
push! (ret_activity, act)
394
344
if act == enzyme_out || act == enzyme_outnoneed
395
345
attr = MLIR. IR. DenseElementsAttribute (
@@ -440,7 +390,7 @@ function overload_autodiff(
440
390
outputs= outtys,
441
391
fn= fname,
442
392
width,
443
- strong_zero= get_strongzero ( ),
393
+ strong_zero= Enzyme . strong_zero (CMode ),
444
394
activity= MLIR. IR. Attribute ([act_attr (a) for a in activity]),
445
395
ret_activity= MLIR. IR. Attribute ([act_attr (a) for a in ret_activity]),
446
396
)
@@ -462,7 +412,7 @@ function overload_autodiff(
462
412
463
413
for a in linear_results
464
414
if TracedUtils. has_idx (a, resprefix)
465
- if needs_primal (CMode)
415
+ if Enzyme . needs_primal (CMode)
466
416
path = TracedUtils. get_idx (a, resprefix)
467
417
tval = TracedUtils. transpose_val (MLIR. IR. result (res, residx))
468
418
TracedUtils. set! (result, path[2 : end ], tval)
@@ -567,14 +517,14 @@ function overload_autodiff(
567
517
func2. operation = MLIR. API. MlirOperation (C_NULL )
568
518
569
519
if reverse
570
- resv = if needs_primal (CMode)
520
+ resv = if Enzyme . needs_primal (CMode)
571
521
result
572
522
else
573
523
nothing
574
524
end
575
525
return ((restup... ,), resv)
576
526
else
577
- if needs_primal (CMode)
527
+ if Enzyme . needs_primal (CMode)
578
528
if CMode <: Enzyme.ForwardMode && ! (A <: Enzyme.Const )
579
529
(dresult, result)
580
530
else
0 commit comments