Skip to content

Commit 1c0668e

Browse files
wsmosesTestHit
andauthored
fix windows (#2581)
* fix * fix --------- Co-authored-by: William S. Moses <[email protected]>
1 parent 25836d0 commit 1c0668e

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

src/compiler.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,6 +1997,17 @@ end
19971997
include("rules/allocrules.jl")
19981998
include("rules/llvmrules.jl")
19991999

2000+
function add_one_in_place(x)
2001+
if x isa Base.RefValue
2002+
x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x))))
2003+
elseif x isa (Array{T,0} where T)
2004+
x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x))))
2005+
else
2006+
throw(EnzymeNonScalarReturnException(x, ""))
2007+
end
2008+
return nothing
2009+
end
2010+
20002011
for (k, v) in (
20012012
("enz_runtime_newtask_fwd", Enzyme.Compiler.runtime_newtask_fwd),
20022013
("enz_runtime_newtask_augfwd", Enzyme.Compiler.runtime_newtask_augfwd),
@@ -2018,6 +2029,7 @@ for (k, v) in (
20182029
("enz_runtime_jl_setfield_rev", Enzyme.Compiler.rt_jl_setfield_rev),
20192030
("enz_runtime_error_if_differentiable", Enzyme.Compiler.error_if_differentiable),
20202031
("enz_runtime_error_if_active", Enzyme.Compiler.error_if_active),
2032+
("enz_add_one_in_place", Enzyme.Compiler.add_one_in_place),
20212033
)
20222034
JuliaEnzymeNameMap[k] = v
20232035
end
@@ -5072,7 +5084,7 @@ end
50725084
if !(primal_target isa GPUCompiler.NativeCompilerTarget)
50735085
reinsert_gcmarker!(adjointf)
50745086
augmented_primalf !== nothing && reinsert_gcmarker!(augmented_primalf)
5075-
post_optimze!(mod, target_machine, false) #=machine=#
5087+
post_optimize!(mod, target_machine, false) #=machine=#
50765088
end
50775089

50785090
adjointf = functions(mod)[adjointf_name]
@@ -5236,17 +5248,6 @@ include("typeutils/recursive_add.jl")
52365248
end
52375249
end
52385250

5239-
function add_one_in_place(x)
5240-
if x isa Base.RefValue
5241-
x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x))))
5242-
elseif x isa (Array{T,0} where T)
5243-
x[] = recursive_add(x[], default_adjoint(eltype(Core.Typeof(x))))
5244-
else
5245-
throw(EnzymeNonScalarReturnException(x, ""))
5246-
end
5247-
return nothing
5248-
end
5249-
52505251
@generated function enzyme_call(
52515252
::Val{RawCall},
52525253
fptr::PT,
@@ -5814,7 +5815,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
58145815
if DumpPrePostOpt[]
58155816
API.EnzymeDumpModuleRef(mod.ref)
58165817
end
5817-
post_optimze!(mod, JIT.get_tm())
5818+
post_optimize!(mod, JIT.get_tm())
58185819
if DumpPostOpt[]
58195820
API.EnzymeDumpModuleRef(mod.ref)
58205821
end

src/compiler/optimize.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ function addJuliaLegalizationPasses!(pm::LLVM.ModulePassManager, tm::LLVM.Target
718718
end
719719
end
720720

721-
function post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
721+
function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
722722
addr13NoAlias(mod)
723723
removeDeadArgs!(mod, tm)
724724
for f in collect(functions(mod))
@@ -764,6 +764,14 @@ function post_optimze!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool =
764764
LLVM.run!(pm, mod)
765765
end
766766
end
767+
for f in functions(mod)
768+
if isempty(blocks(f))
769+
continue
770+
end
771+
if !has_fn_attr(f, StringAttribute("frame-pointer"))
772+
push!(function_attributes(f), StringAttribute("frame-pointer", "all"))
773+
end
774+
end
767775
# @safe_show "post_mod", mod
768776
# flush(stdout)
769777
# flush(stderr)

src/compiler/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ function reflect(
7474
mod, meta = GPUCompiler.codegen(:llvm, job) #= validate=false =#
7575

7676
if second_stage
77-
post_optimze!(mod, JIT.get_tm())
77+
post_optimize!(mod, JIT.get_tm())
7878
end
7979

8080
llvmf = meta.adjointf

0 commit comments

Comments
 (0)