Skip to content

Commit 72d92fb

Browse files
committed
Add tfunc support for copysign and ifelse
Fixes #18
1 parent 692a134 commit 72d92fb

File tree

7 files changed

+69
-12
lines changed

7 files changed

+69
-12
lines changed

src/analysis/ipoincidence.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function _make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt),
167167
return argt
168168
elseif Compiler.isprimitivetype(argt)
169169
inc = Incidence(add_variable!(which))
170-
return argt === Float64 ? inc : Incidence(argt, inc.row, inc.eps)
170+
return argt === Float64 ? inc : Incidence(argt, inc.row)
171171
elseif isa(argt, PartialStruct)
172172
return PartialStruct(𝕃, argt.typ, Any[make_argument_lattice_elem(𝕃, which, f, add_variable!, add_equation!, add_scope!) for f in argt.fields])
173173
elseif isabstracttype(argt) || ismutabletype(argt) || !isa(argt, DataType)

src/analysis/lattice.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ function Compiler.tmeet(🥬::EqStructureLattice, @nospecialize(a), @nospecializ
430430
meet = Compiler.tmeet(Compiler.widenlattice(🥬), a.typ, b)
431431
meet == Union{} && return Union{}
432432
Base.issingletontype(meet) && return meet
433-
return Incidence(meet, copy(a.row), copy(a.eps))
433+
return Incidence(meet, copy(a.row))
434434
elseif isa(a, Eq)
435435
meet = Compiler.tmeet(Compiler.widenlattice(🥬), equation, b)
436436
meet == Union{} && return Union{}
@@ -514,15 +514,15 @@ function Compiler.tmerge(🥬::EqStructureLattice, @nospecialize(a), @nospeciali
514514
row[i] = nonlinear
515515
end
516516
end
517-
return Incidence(merged_typ, row, union(a.eps, b.eps))
517+
return Incidence(merged_typ, row)
518518
elseif isa(b, Const)
519519
# Const has no incidence taint
520520
typ = Compiler.tmerge(Compiler.widenlattice(🥬), a.typ, b)
521521
r = copy(a)
522522
for i in rowvals(r.row)
523523
r.row[i] = nonlinear
524524
end
525-
return Incidence(typ, r.row, copy(a.eps))
525+
return Incidence(typ, r.row)
526526
else
527527
a = widenconst(a)
528528
end

src/analysis/refiner.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,22 @@ function tfunc(::Val{Core.Intrinsics.mul_float}, @nospecialize(a::Union{Const, T
184184
return Incidence(builtin_math_tfunc(Core.Intrinsics.mul_float, a.typ, b.typ), rrow)
185185
end
186186

187+
function tfunc(::Val{Core.Intrinsics.copysign_float}, @nospecialize(a::Union{Const, Type{Float64}, Incidence}), @nospecialize(b::Union{Const, Type{Float64}, Incidence}))
188+
if a === Float64 || b === Float64
189+
return Float64
190+
end
191+
if isa(a, Const) && isa(b, Const)
192+
return builtin_math_tfunc(Core.Intrinsics.copysign_float, a, b)
193+
end
194+
rrow = _zero_row()
195+
arow = isa(a, Incidence) ? a.row : _ZERO_ROW
196+
brow = isa(b, Incidence) ? b.row : _ZERO_ROW
197+
for i in Iterators.flatten((rowvals(arow), rowvals(brow)))
198+
rrow[i] = nonlinear
199+
end
200+
return Incidence(builtin_math_tfunc(Core.Intrinsics.copysign_float, widenconst(a), widenconst(b)), rrow)
201+
end
202+
187203
function tfunc(::Val{Core.Intrinsics.div_float}, @nospecialize(a::Union{Const, Type{Float64}, Incidence}), @nospecialize(b::Union{Const, Type{Float64}, Incidence}))
188204
if isa(a, Const) && isa(b, Const)
189205
return builtin_math_tfunc(Core.Intrinsics.div_float, a, b)
@@ -258,18 +274,19 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
258274
b = argtypes[2]
259275
if is_any_incidence(a, b)
260276
if (f == Core.Intrinsics.add_float || f == Core.Intrinsics.sub_float) ||
261-
(f == Core.Intrinsics.mul_float || f == Core.Intrinsics.div_float)
277+
(f == Core.Intrinsics.mul_float || f == Core.Intrinsics.div_float) ||
278+
f == Core.Intrinsics.copysign_float
262279
return tfunc(Val(f), a, b)
263280
elseif f == Core.Intrinsics.lt_float
264-
r = tmerge(typeinf_lattice(interp), argtypes[1], argtypes[2])
281+
r = Compiler.tmerge(Compiler.typeinf_lattice(interp), argtypes[1], argtypes[2])
265282
@assert isa(r, Incidence)
266-
return Incidence(Bool, r.row, r.eps)
283+
return Incidence(Bool, r.row)
267284
elseif f === Core.getfield && isa(a, Incidence)
268285
a = argtypes[1]
269-
fT = getfield_tfunc(typeinf_lattice(interp), widenconst(a), argtypes[2])
286+
fT = Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), widenconst(a), argtypes[2])
270287
fT === Union{} && return Union{}
271288
Base.issingletontype(fT) && return fT
272-
return Incidence(fT, copy(a.row), copy(a.eps))
289+
return Incidence(fT, copy(a.row))
273290
end
274291
end
275292
elseif length(argtypes) == 3
@@ -297,7 +314,19 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
297314
return c
298315
end
299316
end
300-
# TODO: tmergea
317+
rt = Compiler.tmerge(Compiler.typeinf_lattice(interp), b, c)
318+
if isa(rt, Incidence)
319+
if isa(a, Incidence)
320+
rrow = copy(rt.row)
321+
for i in rowvals(a.row)
322+
rrow[i] = nonlinear
323+
end
324+
rt = Incidence(rt.typ, rrow)
325+
else
326+
rt = widenconst(rt)
327+
end
328+
end
329+
return rt
301330
end
302331
end
303332
end

src/transform/codegen/init_uncompress.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ function gen_init_uncompress!(
120120

121121
if is_known_invoke(stmt, variable, ir) || is_equation_call(stmt, ir)
122122
display(ir)
123+
@show stmt
123124
error()
124125
elseif is_known_invoke(stmt, equation, ir)
125126
# Equation - used, but only as an arg to equation call, which will all get

src/transform/tearing/schedule.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, val::Unio
168168
f = argextype(stmt.args[1], ir)
169169
@assert isa(f, Const)
170170
f = f.val
171-
@assert f in (Core.Intrinsics.sub_float, Core.Intrinsics.add_float, Core.Intrinsics.mul_float)
171+
@assert f in (Core.Intrinsics.sub_float, Core.Intrinsics.add_float,
172+
Core.Intrinsics.mul_float, Core.Intrinsics.copysign_float,
173+
Core.ifelse)
172174
call_is_linear = f in (Core.Intrinsics.sub_float, Core.Intrinsics.add_float)
173175
end
174176

test/regression.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module Regression
2+
3+
using Test
4+
using DAECompiler
5+
using DAECompiler.Intrinsics
6+
using Sundials
7+
using SciMLBase
8+
using OrdinaryDiffEq
9+
10+
const cf = Base.copysign_float
11+
const ief = Core.ifelse
12+
const -= Core.Intrinsics.sub_float
13+
14+
function tfb1()
15+
x = continuous()
16+
b = (x < 43200.)
17+
v = ief(b, x, cf(0., x))
18+
always!(ddt(x) -ᵢ v)
19+
end
20+
21+
sol = solve(DAECProblem(tfb1, (1,) .=> 1.), IDA())
22+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(sol.t)))
23+
24+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
include("basic.jl")
22
include("ipo.jl")
3-
include("ssrm.jl")
3+
include("ssrm.jl")
4+
include("regression.jl")

0 commit comments

Comments
 (0)