Skip to content

Commit 9945bee

Browse files
committed
Support for sin/cos
1 parent 4dfbc2b commit 9945bee

File tree

5 files changed

+1079
-18
lines changed

5 files changed

+1079
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SourceCodeMcCormick"
22
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
33
authors = ["Robert Gottlieb <[email protected]>"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/kernel_writer/kernel_write.jl

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr
1010
kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
1111
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool)
1212
# Create a hash of the expression and check if the function already exists
13-
expr_hash = string(hash(num+sum(gradlist)), base=62)
13+
expr_hash = string(hash(string(num)*string(gradlist)), base=62)
14+
# expr_hash = string(hash(num+sum(gradlist)), base=62)
1415
if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")))
1516
try func_name = eval(Meta.parse("f_"*expr_hash))
1617
return func_name
@@ -116,7 +117,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
116117
sparsity = detect_sparsity(factored, gradlist)
117118

118119
# Decide if the kernel needs to be split
119-
if (n_vars[end] < 31) && (n_lines[end] <= max_size)
120+
if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines)))
120121
# Complexity is fairly low; only a single kernel needed
121122
create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity)
122123
push!(kernel_nums, 1)
@@ -128,9 +129,14 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
128129
kernel_count = 1
129130
# structure_list = String[] # Experimental
130131
while !complete
132+
# println("Kernel: $kernel_count")
133+
# for j in 1:length(n_lines)
134+
# println("$j : $(factored[j]), $(n_lines[j]), $(n_vars[j])")
135+
# end
136+
# println("")
131137
# Determine which line to break at
132138
line_ID = findfirst(x -> x > split_point, n_lines)
133-
vars_ID = findfirst(x -> x == 31, n_vars)
139+
vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars)
134140
if isnothing(vars_ID)
135141
new_ID = line_ID
136142
elseif isnothing(line_ID)
@@ -188,7 +194,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
188194
n_lines = complexity(factored)
189195
n_vars = var_counts(factored)
190196

191-
# If the total number of lines (not including the final line) is below 2000
197+
# If the total number of lines (not including the final line) is below the max size
192198
# and the number of variables is below 32, we can make the final kernel and be done
193199
if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size))
194200
create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity)
@@ -328,7 +334,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
328334
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
329335

330336
# Put in the preamble.
331-
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
337+
if isempty(vars)
338+
write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist)))
339+
else
340+
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
341+
end
342+
332343

333344
# Depending on the format of the expression, compose the kernel differently
334345
if typeof(expr) <: Real
@@ -360,9 +371,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
360371
end
361372
end
362373
else # There must be two elements in the dictionary
363-
binary_vars = string.(get_name.(keys(key.dict)))
374+
binary_vars = string.(get_name.(keys(expr.dict)))
364375
binary_vars = binary_vars[sort_vars(binary_vars)]
365-
write(file, SCMC_quadaff_binary(vars..., expr.coeff, varlist))
376+
write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist))
366377
end
367378

368379
elseif exprtype(expr)==ADD
@@ -394,7 +405,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
394405
# EAGO already does this and bypasses the need to calculate relaxations.
395406
# But, for compatibility with McCormick-style relaxations in ParBB,
396407
# it's easier to simply calculate what ParBB is expecting.)
397-
write(file, postamble_quadaff(string.(vars), varlist))
408+
if isempty(varlist)
409+
write(file, postamble_quadaff(String[], String[]))
410+
elseif isempty(vars)
411+
write(file, postamble_quadaff(String[], varlist))
412+
else
413+
write(file, postamble_quadaff(string.(vars), varlist))
414+
end
398415
close(file)
399416

400417
# Include this kernel so SCMC knows what it is
@@ -403,7 +420,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
403420
# Add onto the file the "main" CPU function that calls the kernel
404421
blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
405422
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
406-
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
423+
if isempty(gradlist)
424+
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[]))
425+
elseif isempty(vars)
426+
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist)))
427+
else
428+
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
429+
end
407430
close(file)
408431

409432
# Include the file again to get the final kernel
@@ -731,6 +754,7 @@ end
731754
# 7) log(inv(x1)) = -log(x1) [EAGO paper]
732755
# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
733756
# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
757+
# 10) sin(x) = cos(x - pi/2)
734758
#
735759
# Forms that aren't relevant yet:
736760
# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
@@ -826,7 +850,7 @@ function perform_substitutions(old_factored::Vector{Equation})
826850
end
827851
end
828852
# Create a factorization of this new expr
829-
new_factorization = factor(new_expr)
853+
new_factorization = factor(new_expr, split_div=true)
830854
# Scan through the new factorization to see if we can merge elements
831855
# with the original factored list
832856
done = false
@@ -1191,7 +1215,7 @@ function perform_substitutions(old_factored::Vector{Equation})
11911215
new_expr *= arg
11921216
end
11931217
# Create a factorization of this new expr
1194-
new_factorization = factor(new_expr)
1218+
new_factorization = factor(new_expr, split_div=true)
11951219

11961220

11971221
# Scan through the new factorization to see if we can merge elements
@@ -1315,6 +1339,38 @@ function perform_substitutions(old_factored::Vector{Equation})
13151339
end
13161340
end
13171341
end
1342+
1343+
# 10) sin(x) = cos(x - pi/2)
1344+
if exprtype(factored[index0].rhs)==TERM
1345+
if factored[index0].rhs.f==sin
1346+
# We found sin(arg). Check if (arg - pi/2) exists,
1347+
# and if so, also check if cos(arg - pi/2) exists.
1348+
scan_flag = true
1349+
index1 = findfirst(x -> isequal(x.rhs, arguments(factored[index0].rhs)[] - pi/2), factored)
1350+
if !isnothing(index1)
1351+
index2 = findfirst(x -> isequal(x.rhs, cos(factored[index1].lhs)), factored)
1352+
if !isnothing(index2)
1353+
# cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2
1354+
for i in eachindex(factored)
1355+
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$index2].lhs))
1356+
end
1357+
deleteat!(factored, index0)
1358+
else
1359+
# arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change
1360+
# index0 to be cos of index1.lhs instead of sin of arg
1361+
@eval $factored[$index0] = $factored[$index0].lhs ~ cos($factored[$index1].lhs)
1362+
end
1363+
else
1364+
# (arg - pi/2) doesn't exist, so we need to create it
1365+
newsym = gensym(:aux)
1366+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
1367+
newvar = genvar(newsym)
1368+
insert!(factored, index0, Equation(Symbolics.value(newvar), arguments(factored[index0].rhs)[] - pi/2))
1369+
@eval $factored[$index0+1] = $factored[$index0+1].lhs ~ cos($newvar)
1370+
end
1371+
break
1372+
end
1373+
end
13181374
end
13191375
end
13201376

@@ -1511,6 +1567,8 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
15111567
write(file, SCMC_sigmoid_kernel(inputs..., gradlist, sparsity))
15121568
elseif RHS.f==sqrt
15131569
write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity))
1570+
elseif RHS.f==cos
1571+
write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity))
15141572
else
15151573
close(file)
15161574
error("Some function was used that we can't handle yet ($RHS)")
@@ -1845,6 +1903,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18451903
else
18461904
total_lines += 190
18471905
end
1906+
new_ID = findfirst(x -> isequal(x.lhs, RHS.base), factorized)
1907+
if !isnothing(new_ID)
1908+
total_lines += _complexity(complexity, factorized, new_ID)
1909+
end
18481910
elseif exprtype(RHS) == TERM
18491911
if RHS.f==exp
18501912
total_lines += 212 # Ranges from 212--310
@@ -1866,6 +1928,16 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18661928
end
18671929
elseif RHS.f==sqrt
18681930
total_lines += 190
1931+
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
1932+
if !isnothing(new_ID)
1933+
total_lines += _complexity(complexity, factorized, new_ID)
1934+
end
1935+
elseif RHS.f==cos || RHS.f==sin
1936+
total_lines += 300
1937+
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
1938+
if !isnothing(new_ID)
1939+
total_lines += _complexity(complexity, factorized, new_ID)
1940+
end
18691941
else
18701942
error("Unknown function")
18711943
end

0 commit comments

Comments
 (0)