@@ -10,7 +10,8 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr
1010kgen (num:: Num , gradlist:: Vector{Num} , raw_outputs:: Vector{Symbol} ; constants:: Vector{Num} = Num[], overwrite:: Bool = false , splitting:: Symbol = :default , affine_quadratic:: Bool = true ) = kgen (num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
1111function kgen (num:: Num , gradlist:: Vector{Num} , raw_outputs:: Vector{Symbol} , constants:: Vector{Num} , overwrite:: Bool , splitting:: Symbol , affine_quadratic:: Bool )
1212 # Create a hash of the expression and check if the function already exists
13- expr_hash = string (hash (num+ sum (gradlist)), base= 62 )
13+ expr_hash = string (hash (string (num)* string (gradlist)), base= 62 )
14+ # expr_hash = string(hash(num+sum(gradlist)), base=62)
1415 if (overwrite== false ) && (isfile (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" )))
1516 try func_name = eval (Meta. parse (" f_" * expr_hash))
1617 return func_name
@@ -116,7 +117,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
116117 sparsity = detect_sparsity (factored, gradlist)
117118
118119 # Decide if the kernel needs to be split
119- if (n_vars[end ] < 31 ) && (n_lines[end ] <= max_size)
120+ if (n_vars[end ] < 31 ) && (( n_lines[end ] <= max_size) || ( findfirst (x -> x > split_point, n_lines) == length (n_lines)) )
120121 # Complexity is fairly low; only a single kernel needed
121122 create_kernel! (expr_hash, 1 , num, get_name .(gradlist), func_outputs, constants, factored, sparsity)
122123 push! (kernel_nums, 1 )
@@ -128,9 +129,14 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
128129 kernel_count = 1
129130 # structure_list = String[] # Experimental
130131 while ! complete
132+ # println("Kernel: $kernel_count")
133+ # for j in 1:length(n_lines)
134+ # println("$j : $(factored[j]), $(n_lines[j]), $(n_vars[j])")
135+ # end
136+ # println("")
131137 # Determine which line to break at
132138 line_ID = findfirst (x -> x > split_point, n_lines)
133- vars_ID = findfirst (x -> x == 31 , n_vars)
139+ vars_ID = findfirst (x -> ( x == 30 ) || (x == 31 ) , n_vars)
134140 if isnothing (vars_ID)
135141 new_ID = line_ID
136142 elseif isnothing (line_ID)
@@ -188,7 +194,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
188194 n_lines = complexity (factored)
189195 n_vars = var_counts (factored)
190196
191- # If the total number of lines (not including the final line) is below 2000
197+ # If the total number of lines (not including the final line) is below the max size
192198 # and the number of variables is below 32, we can make the final kernel and be done
193199 if (n_vars[end ] < 32 ) && (all (n_lines[1 : end - 1 ] .<= max_size))
194200 create_kernel! (expr_hash, kernel_count, extract (factored), get_name .(gradlist), func_outputs, constants, factored, sparsity)
@@ -328,7 +334,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
328334 file = open (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" ), " a" )
329335
330336 # Put in the preamble.
331- write (file, preamble_string (expr_hash, [" OUT" ; string .(vars)], 1 , 1 , length (gradlist)))
337+ if isempty (vars)
338+ write (file, preamble_string (expr_hash, [" OUT" ;], 1 , 1 , length (gradlist)))
339+ else
340+ write (file, preamble_string (expr_hash, [" OUT" ; string .(vars)], 1 , 1 , length (gradlist)))
341+ end
342+
332343
333344 # Depending on the format of the expression, compose the kernel differently
334345 if typeof (expr) <: Real
@@ -360,9 +371,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
360371 end
361372 end
362373 else # There must be two elements in the dictionary
363- binary_vars = string .(get_name .(keys (key . dict)))
374+ binary_vars = string .(get_name .(keys (expr . dict)))
364375 binary_vars = binary_vars[sort_vars (binary_vars)]
365- write (file, SCMC_quadaff_binary (vars ... , expr. coeff, varlist))
376+ write (file, SCMC_quadaff_binary (binary_vars ... , expr. coeff, varlist))
366377 end
367378
368379 elseif exprtype (expr)== ADD
@@ -394,7 +405,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
394405 # EAGO already does this and bypasses the need to calculate relaxations.
395406 # But, for compatibility with McCormick-style relaxations in ParBB,
396407 # it's easier to simply calculate what ParBB is expecting.)
397- write (file, postamble_quadaff (string .(vars), varlist))
408+ if isempty (varlist)
409+ write (file, postamble_quadaff (String[], String[]))
410+ elseif isempty (vars)
411+ write (file, postamble_quadaff (String[], varlist))
412+ else
413+ write (file, postamble_quadaff (string .(vars), varlist))
414+ end
398415 close (file)
399416
400417 # Include this kernel so SCMC knows what it is
@@ -403,7 +420,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
403420 # Add onto the file the "main" CPU function that calls the kernel
404421 blocks = Int32 (CUDA. attribute (CUDA. device (), CUDA. DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
405422 file = open (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" ), " a" )
406- write (file, outro (expr_hash, [1 ], [string .(vars)], [" OUT" ], blocks, get_name .(gradlist)))
423+ if isempty (gradlist)
424+ write (file, outro (expr_hash, [1 ], [String[]], [" OUT" ], blocks, Symbol[]))
425+ elseif isempty (vars)
426+ write (file, outro (expr_hash, [1 ], [String[]], [" OUT" ], blocks, get_name .(gradlist)))
427+ else
428+ write (file, outro (expr_hash, [1 ], [string .(vars)], [" OUT" ], blocks, get_name .(gradlist)))
429+ end
407430 close (file)
408431
409432 # Include the file again to get the final kernel
731754# 7) log(inv(x1)) = -log(x1) [EAGO paper]
732755# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
733756# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
757+ # 10) sin(x) = cos(x - pi/2)
734758#
735759# Forms that aren't relevant yet:
736760# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
@@ -826,7 +850,7 @@ function perform_substitutions(old_factored::Vector{Equation})
826850 end
827851 end
828852 # Create a factorization of this new expr
829- new_factorization = factor (new_expr)
853+ new_factorization = factor (new_expr, split_div = true )
830854 # Scan through the new factorization to see if we can merge elements
831855 # with the original factored list
832856 done = false
@@ -1191,7 +1215,7 @@ function perform_substitutions(old_factored::Vector{Equation})
11911215 new_expr *= arg
11921216 end
11931217 # Create a factorization of this new expr
1194- new_factorization = factor (new_expr)
1218+ new_factorization = factor (new_expr, split_div = true )
11951219
11961220
11971221 # Scan through the new factorization to see if we can merge elements
@@ -1315,6 +1339,38 @@ function perform_substitutions(old_factored::Vector{Equation})
13151339 end
13161340 end
13171341 end
1342+
1343+ # 10) sin(x) = cos(x - pi/2)
1344+ if exprtype (factored[index0]. rhs)== TERM
1345+ if factored[index0]. rhs. f== sin
1346+ # We found sin(arg). Check if (arg - pi/2) exists,
1347+ # and if so, also check if cos(arg - pi/2) exists.
1348+ scan_flag = true
1349+ index1 = findfirst (x -> isequal (x. rhs, arguments (factored[index0]. rhs)[] - pi / 2 ), factored)
1350+ if ! isnothing (index1)
1351+ index2 = findfirst (x -> isequal (x. rhs, cos (factored[index1]. lhs)), factored)
1352+ if ! isnothing (index2)
1353+ # cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2
1354+ for i in eachindex (factored)
1355+ @eval $ factored[$ i] = $ factored[$ i]. lhs ~ substitute ($ factored[$ i]. rhs, Dict ($ factored[$ index0]. lhs => $ factored[$ index2]. lhs))
1356+ end
1357+ deleteat! (factored, index0)
1358+ else
1359+ # arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change
1360+ # index0 to be cos of index1.lhs instead of sin of arg
1361+ @eval $ factored[$ index0] = $ factored[$ index0]. lhs ~ cos ($ factored[$ index1]. lhs)
1362+ end
1363+ else
1364+ # (arg - pi/2) doesn't exist, so we need to create it
1365+ newsym = gensym (:aux )
1366+ newsym = Symbol (string (newsym)[3 : 5 ] * string (newsym)[7 : end ])
1367+ newvar = genvar (newsym)
1368+ insert! (factored, index0, Equation (Symbolics. value (newvar), arguments (factored[index0]. rhs)[] - pi / 2 ))
1369+ @eval $ factored[$ index0+ 1 ] = $ factored[$ index0+ 1 ]. lhs ~ cos ($ newvar)
1370+ end
1371+ break
1372+ end
1373+ end
13181374 end
13191375 end
13201376
@@ -1511,6 +1567,8 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
15111567 write (file, SCMC_sigmoid_kernel (inputs... , gradlist, sparsity))
15121568 elseif RHS. f== sqrt
15131569 write (file, SCMC_float_power_kernel (inputs... , 0.5 , gradlist, sparsity))
1570+ elseif RHS. f== cos
1571+ write (file, SCMC_cos_kernel (inputs... , gradlist, sparsity))
15141572 else
15151573 close (file)
15161574 error (" Some function was used that we can't handle yet ($RHS )" )
@@ -1845,6 +1903,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18451903 else
18461904 total_lines += 190
18471905 end
1906+ new_ID = findfirst (x -> isequal (x. lhs, RHS. base), factorized)
1907+ if ! isnothing (new_ID)
1908+ total_lines += _complexity (complexity, factorized, new_ID)
1909+ end
18481910 elseif exprtype (RHS) == TERM
18491911 if RHS. f== exp
18501912 total_lines += 212 # Ranges from 212--310
@@ -1866,6 +1928,16 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18661928 end
18671929 elseif RHS. f== sqrt
18681930 total_lines += 190
1931+ new_ID = findfirst (x -> isequal (x. lhs, RHS. arguments[1 ]), factorized)
1932+ if ! isnothing (new_ID)
1933+ total_lines += _complexity (complexity, factorized, new_ID)
1934+ end
1935+ elseif RHS. f== cos || RHS. f== sin
1936+ total_lines += 300
1937+ new_ID = findfirst (x -> isequal (x. lhs, RHS. arguments[1 ]), factorized)
1938+ if ! isnothing (new_ID)
1939+ total_lines += _complexity (complexity, factorized, new_ID)
1940+ end
18691941 else
18701942 error (" Unknown function" )
18711943 end
0 commit comments