Skip to content

Commit 927a97a

Browse files
committed
Improve intrinsic coverage
Fixes #19.
1 parent 88454ee commit 927a97a

File tree

4 files changed

+96
-30
lines changed

4 files changed

+96
-30
lines changed

src/analysis/refiner.jl

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ function structural_inc_ddt(var_to_diff::DiffGraph, varclassification::Union{Vec
118118
end
119119

120120
#==================== Base Math Intrinsic Refinement Models ===========================#
121-
function tfunc(::Val{Core.Intrinsics.neg_float}, @nospecialize(a::Union{Const, Incidence}))
121+
function tfunc(F::Union{Val{Core.Intrinsics.neg_float}, Val{Core.Intrinsics.neg_int}}, @nospecialize(a::Union{Const, Incidence}))
122122
if isa(a, Incidence)
123123
arow = copy(a.row)
124124
for (i, v) in zip(rowvals(a.row), nonzeros(a.row))
@@ -127,36 +127,36 @@ function tfunc(::Val{Core.Intrinsics.neg_float}, @nospecialize(a::Union{Const, I
127127
else
128128
arow = _zero_row()
129129
end
130-
return Incidence(builtin_math_tfunc(Core.Intrinsics.neg_float, isa(a, Incidence) ? a.typ : a), arow)
130+
return Incidence(builtin_math_tfunc(typeof(F).parameters[1], isa(a, Incidence) ? a.typ : a), arow)
131131
end
132132

133133
get_eps(inc::Incidence) = inc.eps
134134
get_eps(c::Const) = BitSet()
135135
get_eps(::Type) = error()
136136

137-
function tfunc(::Val{Core.Intrinsics.add_float}, @nospecialize(a::Union{Const, Type{Float64}, Incidence}), @nospecialize(b::Union{Const, Type{Float64}, Incidence}))
137+
function tfunc(F::Union{Val{Core.Intrinsics.add_float}, Val{Core.Intrinsics.add_int}}, @nospecialize(a::Union{Const, Type{Float64}, Incidence}), @nospecialize(b::Union{Const, Type{Float64}, Incidence}))
138138
if a === Float64 || b === Float64
139139
return Float64
140140
end
141-
isa(a, Const) && isa(b, Const) && return builtin_math_tfunc(Core.Intrinsics.add_float, a, b)
141+
isa(a, Const) && isa(b, Const) && return builtin_math_tfunc(typeof(F).parameters[1], a, b)
142142
arow = isa(a, Incidence) ? a.row : _ZERO_ROW
143143
brow = isa(b, Incidence) ? b.row : _ZERO_ROW
144144
rrow = copy(arow) .+= brow
145-
const_val = builtin_math_tfunc(Core.Intrinsics.add_float, isa(a, Incidence) ? a.typ : a, isa(b, Incidence) ? b.typ : b)
145+
const_val = builtin_math_tfunc(typeof(F).parameters[1], isa(a, Incidence) ? a.typ : a, isa(b, Incidence) ? b.typ : b)
146146
if isa(const_val, Const) && !any(!iszero, rrow)
147147
return const_val
148148
end
149149
return Incidence(const_val, rrow)
150150
end
151151

152-
function tfunc(::Val{Core.Intrinsics.sub_float}, @nospecialize(a::Union{Const, Incidence}), @nospecialize(b::Union{Const, Incidence}))
153-
isa(a, Const) && isa(b, Const) && return builtin_math_tfunc(Core.Intrinsics.sub_float, a, b)
152+
function tfunc(F::Union{Val{Core.Intrinsics.sub_float}, Val{Core.Intrinsics.sub_int}}, @nospecialize(a::Union{Const, Incidence}), @nospecialize(b::Union{Const, Incidence}))
153+
isa(a, Const) && isa(b, Const) && return builtin_math_tfunc(typeof(F).parameters[1], a, b)
154154
arow = isa(a, Incidence) ? a.row : _ZERO_ROW
155155
brow = isa(b, Incidence) ? b.row : _ZERO_ROW
156156
# return Incidence(a.row + b.row), but see https://github.com/JuliaArrays/OffsetArrays.jl/issues/299
157157
# and https://github.com/JuliaSparse/SparseArrays.jl/issues/101
158158
rrow = copy(arow) .-= brow
159-
const_val = builtin_math_tfunc(Core.Intrinsics.sub_float, isa(a, Incidence) ? a.typ : a, isa(b, Incidence) ? b.typ : b)
159+
const_val = builtin_math_tfunc(typeof(F).parameters[1], isa(a, Incidence) ? a.typ : a, isa(b, Incidence) ? b.typ : b)
160160
if isa(const_val, Const) && !any(!iszero, rrow)
161161
return const_val
162162
end
@@ -218,20 +218,32 @@ function tfunc(::Val{Core.Intrinsics.div_float}, @nospecialize(a::Union{Const, T
218218
return Incidence(builtin_math_tfunc(Core.Intrinsics.div_float, isa(a, Incidence) ? a.typ : a, widenconst(b.typ)), rrow)
219219
end
220220

221-
function tfunc(::Val{Core.Intrinsics.or_int}, @nospecialize(a::Union{Const, Type, Incidence}), @nospecialize(b::Union{Const, Type, Incidence}))
221+
function generic_math_twoarg(f, @nospecialize(a::Union{Const, Type, Incidence}), @nospecialize(b::Union{Const, Type, Incidence}))
222222
if isa(a, Const) && isa(b, Const)
223-
return builtin_math_tfunc(Core.Intrinsics.or_int, a, b)
223+
return builtin_math_tfunc(f, a, b)
224224
end
225225
if !isa(a, Incidence) && !isa(b, Incidence)
226-
return builtin_math_tfunc(Core.Intrinsics.or_int, a, b)
226+
return builtin_math_tfunc(f, a, b)
227227
end
228228
rrow = _zero_row()
229229
arow = isa(a, Incidence) ? a.row : _ZERO_ROW
230230
brow = isa(b, Incidence) ? b.row : _ZERO_ROW
231231
for i in Iterators.flatten((rowvals(arow), rowvals(brow)))
232232
rrow[i] = nonlinear
233233
end
234-
return Incidence(builtin_math_tfunc(Core.Intrinsics.or_int, widenconst(a), widenconst(b)), rrow)
234+
return Incidence(builtin_math_tfunc(f, widenconst(a), widenconst(b)), rrow)
235+
end
236+
237+
238+
function generic_math_onearg(f, @nospecialize(a::Union{Const, Type, Incidence}))
239+
if isa(a, Const) || !isa(a, Incidence)
240+
return builtin_math_tfunc(f, a)
241+
end
242+
rrow = _zero_row()
243+
for i in rowvals(a.row)
244+
rrow[i] = nonlinear
245+
end
246+
return Incidence(builtin_math_tfunc(f, widenconst(a)), rrow)
235247
end
236248

237249
function tfunc(::Val{Core.Intrinsics.and_int}, @nospecialize(a::Union{Const, Type, Incidence}), @nospecialize(b::Union{Const, Type, Incidence}))
@@ -292,46 +304,72 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
292304
@nospecialize(f), argtypes::Vector{Any}, sv::Union{Compiler.AbsIntState,Nothing})
293305

294306
bargtypes = argtypes
307+
308+
if f === Core.getfield
309+
if length(argtypes) == 1 || length(argtypes) > 4
310+
return Union{}
311+
end
312+
313+
a = argtypes[1]
314+
b = argtypes[2]
315+
316+
if isa(a, Const)
317+
if isa(b, Const)
318+
return Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), a, b)
319+
elseif isa(b, Incidence)
320+
fT = Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), a, widenconst(b))
321+
fT === Union{} && return Union{}
322+
Base.issingletontype(fT) && return fT
323+
return Incidence(fT, copy(b.row))
324+
end
325+
return Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), a, b)
326+
elseif isa(a, Incidence)
327+
fT = Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), widenconst(a), b)
328+
fT === Union{} && return Union{}
329+
Base.issingletontype(fT) && return fT
330+
return Incidence(fT, copy(a.row))
331+
end
332+
return Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), a, b)
333+
end
334+
295335
if length(argtypes) == 1
336+
if f === Core.Intrinsics.have_fma
337+
return Incidence(Bool)
338+
end
296339
a = argtypes[1]
297340
if is_any_incidence(a)
298-
if f == Core.Intrinsics.neg_float
341+
if f == Core.Intrinsics.neg_float || f === Core.Intrinsics.neg_int
299342
return tfunc(Val(f), a)
343+
elseif f === Core.Intrinsics.ctlz_int || f === Core.Intrinsics.not_int || f === Core.Intrinsics.abs_float
344+
return generic_math_onearg(f, a)
300345
end
301346
end
302347
elseif length(argtypes) == 2
303348
a = argtypes[1]
304349
b = argtypes[2]
305350
if is_any_incidence(a, b)
306351
if (f == Core.Intrinsics.add_float || f == Core.Intrinsics.sub_float) ||
352+
(f == Core.Intrinsics.add_int || f == Core.Intrinsics.sub_int) ||
307353
(f == Core.Intrinsics.mul_float || f == Core.Intrinsics.div_float) ||
308354
f == Core.Intrinsics.copysign_float
309355
return tfunc(Val(f), a, b)
310-
elseif f == Core.Intrinsics.or_int
311-
return tfunc(Val(f), a, b)
312-
elseif f == Core.Intrinsics.and_int
313-
return tfunc(Val(f), a, b)
314-
elseif f == Core.Intrinsics.fptosi || f == Core.Intrinsics.sitofp
356+
elseif f in (Core.Intrinsics.or_int, Core.Intrinsics.and_int, Core.Intrinsics.xor_int, Core.Intrinsics.shl_int, Core.Intrinsics.lshr_int, Core.Intrinsics.flipsign_int)
357+
return generic_math_twoarg(f, a, b)
358+
elseif f == Core.Intrinsics.fptosi || f == Core.Intrinsics.sitofp || f == Core.Intrinsics.bitcast || f == Core.Intrinsics.trunc_int || f == Core.Intrinsics.zext_int || f == Core.Intrinsics.sext_int
315359
# We keep the linearity structure here and absorb the rounding error into be base Int64
316360
return Incidence(Compiler.conversion_tfunc(Compiler.typeinf_lattice(interp), widenconst(a), widenconst(b)), b.row)
317-
elseif f == Core.Intrinsics.lt_float || f == Core.Intrinsics.eq_float || f == Core.Intrinsics.slt_int
361+
elseif f == Core.Intrinsics.lt_float || f == Core.Intrinsics.ne_float || f == Core.Intrinsics.eq_float || f == Core.Intrinsics.slt_int || f == Core.Intrinsics.sle_int || f == Core.Intrinsics.ult_int || f == Core.Intrinsics.ule_int || f == Core.Intrinsics.eq_int || f == Base.:(===)
318362
r = Compiler.tmerge(Compiler.typeinf_lattice(interp), argtypes[1], argtypes[2])
319363
@assert isa(r, Incidence)
320364
return Incidence(Bool, r.row)
321-
elseif f === Core.getfield && isa(a, Incidence)
322-
a = argtypes[1]
323-
fT = Compiler.getfield_tfunc(Compiler.typeinf_lattice(interp), widenconst(a), argtypes[2])
324-
fT === Union{} && return Union{}
325-
Base.issingletontype(fT) && return fT
326-
return Incidence(fT, copy(a.row))
327365
end
328366
end
329367
elseif length(argtypes) == 3
330368
a = argtypes[1]
331369
b = argtypes[2]
332370
c = argtypes[3]
333371
if is_any_incidence(a, b, c)
334-
if f === Core.Intrinsics.muladd_float
372+
if f === Core.Intrinsics.muladd_float || f === Core.Intrinsics.fma_float
335373
# TODO: muladd vs fma here?
336374
if is_any_incidence(a, b)
337375
x = tfunc(Val(Core.Intrinsics.mul_float), a, b)

src/analysis/structural.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,9 @@ function _structural_analysis!(ci::CodeInstance, world::UInt)
279279
inst = ir[SSAValue(i)]
280280
stmt = inst[:stmt]
281281
stmt === nothing && continue
282+
# No need to process error paths - even if they were to contain intrinsics, such intrinsics would have
283+
# no effect.
284+
inst[:type] === Union{} && continue
282285
isexpr(stmt, :invoke) || continue
283286
is_known_invoke(stmt, variable, ir) && continue
284287
is_known_invoke(stmt, equation, ir) && continue

src/transform/tearing/schedule.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ function schedule_incidence!(compact, var_eq_matching, curval, incT::Incidence,
106106
continue
107107
end
108108
end
109-
109+
\
110110
acc = ir_mul_const!(compact, line, coeff, lin_var_ssa)
111111
curval = curval === nothing ? acc : ir_add!(compact, line, curval, acc)
112112
end
@@ -170,8 +170,10 @@ function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, val::Unio
170170
f = f.val
171171
@assert f in (Core.Intrinsics.sub_float, Core.Intrinsics.add_float,
172172
Core.Intrinsics.mul_float, Core.Intrinsics.copysign_float,
173-
Core.ifelse, Core.Intrinsics.or_int, Core.Intrinsics.and_int)
173+
Core.ifelse, Core.Intrinsics.or_int, Core.Intrinsics.and_int,
174+
Core.Intrinsics.fma_float, Core.Intrinsics.muladd_float)
174175
# TODO: or_int is linear in Bool
176+
# TODO: {fma, muladd}_float is linear in one of its arguments
175177
call_is_linear = f in (Core.Intrinsics.sub_float, Core.Intrinsics.add_float)
176178
end
177179

@@ -184,7 +186,7 @@ function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, val::Unio
184186

185187
# TODO: SICM
186188

187-
if !is_fully_state_linear(typ::Incidence, param_vars)
189+
if !is_const_plus_state_linear(typ::Incidence, param_vars)
188190
this_nonlinear = schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, arg, ssa_rename; vars, schedule_missing_var!)
189191
else
190192
if @isdefined(result)
@@ -204,7 +206,7 @@ function schedule_nonlinear!(compact, param_vars, var_eq_matching, ir, val::Unio
204206
return schedule_incidence!(compact, var_eq_matching, this_nonlinear, typ, -1, inst[:line]; vars, schedule_missing_var!)[1]
205207
end
206208

207-
if is_fully_state_linear(incT, param_vars)
209+
if is_const_plus_state_linear(incT, param_vars)
208210
# TODO: This needs to do a proper template match
209211
ret = schedule_incidence!(compact, var_eq_matching, nothing, info.result.extended_rt, -1, inst[:line]; vars=
210212
[arg === nothing ? 0.0 : arg for arg in args[2:end]])[1]

test/regression.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,27 @@ end
4949
sol = solve(DAECProblem(tfb2_ipo, (1,) .=> 1.), IDA())
5050
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], exp.(sol.t)))
5151

52+
function tfb3()
53+
x = continuous()
54+
r = Base.fma_emulated(x,x,x)
55+
always!(ddt(x) - r)
56+
end
57+
sol = solve(DAECProblem(tfb3, (1,) .=> 1., (0, 0.1)), IDA())
58+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], -exp.(sol.t)./(exp.(sol.t) .- 2)))
59+
60+
function tfb4()
61+
x = continuous()
62+
r = Core.ifelse(Core.Intrinsics.have_fma(Float64), Base.fma_float(x,x,x), Base.fma_emulated(x,x,x))
63+
always!(ddt(x) - r)
64+
end
65+
sol = solve(DAECProblem(tfb4, (1,) .=> 1., (0, 0.1)), IDA())
66+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], -exp.(sol.t)./(exp.(sol.t) .- 2)))
67+
68+
function tfb5()
69+
x = continuous()
70+
always!(ddt(x) - log(1. + sim_time()))
71+
end
72+
sol = solve(DAECProblem(tfb5, (1,) .=> 1.), IDA())
73+
@test all(map((x,y)->isapprox(x[], y, atol=1e-2), sol.u[:, 1], (sol.t .+ 1) .* log.(sol.t .+ 1) .+ 1 .- sol.t))
74+
5275
end

0 commit comments

Comments
 (0)